In [None]:
# Reload imported project modules (e.g. component.script.utilities) before each
# cell runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
# Optional: uncomment to load the cudf.pandas extension (presumably for
# GPU-backed pandas; it sits before the pandas import cell below).
# %load_ext cudf.pandas
# import pandas as pd
# print(pd)


In [None]:
# Optimizations
# GDAL optimizations
import multiprocessing as mp
import os

cpu_count: int = mp.cpu_count()
# Leave two cores free for the OS/notebook, but never request fewer than one
# GDAL thread (machines with <= 2 cores would otherwise get "0" or "-1").
num_cores: int = max(1, cpu_count - 2)
os.environ["GDAL_NUM_THREADS"] = f"{num_cores}"
# GDAL block cache size (small values are interpreted by GDAL as megabytes).
os.environ["GDAL_CACHEMAX"] = "1024"


## Libraries

In [None]:
# Imports
import os
from pathlib import Path

import numpy as np
import pandas as pd
import riskmapjnr as rmj
from typing import Literal


In [None]:
# Add root to path
import sys

# Make the parent folder (the project root) importable for `component.*`;
# skip when already present so re-running this cell does not keep growing sys.path.
if ".." not in sys.path:
    sys.path.append("..")
from component.script.utilities.file_filter import (
    list_files_by_extension,
    filter_files_by_keywords,
)


## Set user parameters

In [None]:
# Name of the project sub-folder under data/ that holds inputs and outputs.
project_name: str = "test"


In [None]:
# Period boundaries: calibration = years[0]->years[1], validation =
# years[1]->years[2], historical = years[0]->years[2].
years: list[int] = [2015, 2020, 2024]
# Forest product keyword used to select input rasters by file name.
forest_source: Literal["gfc", "tmf"] = "gfc"
# NOTE(review): not used in this part of the notebook; presumably the tree
# cover percentage used when the forest rasters were produced — confirm.
tree_cover_threshold: int = 10

# Fail fast on an invalid configuration instead of deep inside the pipeline
# (time intervals are computed as later year minus earlier year).
assert len(years) >= 3 and all(
    earlier < later for earlier, later in zip(years, years[1:])
), f"'years' must hold at least three strictly ascending values, got {years}"
assert forest_source in ("gfc", "tmf"), f"Unsupported forest_source: {forest_source!r}"


In [None]:
# Time-independent variables: their rasters are looked up by name only.
static_variables: list[str] = [
    "altitude",
    "slope",
    "pa",
    "subj",
    "dist_rivers",
    "dist_roads",
]
# Multi-temporal variables: their rasters are also matched on period years.
dynamic_variables: list[str] = [
    "forest",
    "deforestation",
    "forest_edge",
    "dist_towns",
]


In [None]:
# Moving-window sizes passed to rmj.local_defor_rate as `win_size`.
win_size_list: list[int] = [
    5,
    11,
    21,
]
# Rows per processing block, passed as `blk_rows` to the riskmapjnr calls.
block_rows: int = 256


## Set up folders

In [None]:
# Assumes the notebook runs from a sub-folder of the project root; all data
# lives under <root>/data, which is created if missing.
root_folder: Path = Path.cwd().parent
downloads_folder: Path = root_folder.joinpath("data")
downloads_folder.mkdir(exist_ok=True, parents=True)


In [None]:
# Per-project layout: inputs in data/, figures in plots/, and moving-window
# model outputs in rmj_mw/.
project_folder = downloads_folder / project_name
processed_data_folder = project_folder / "data"
plots_folder = project_folder / "plots"
rmj_mw = project_folder / "rmj_mw"
for _folder in (project_folder, processed_data_folder, plots_folder, rmj_mw):
    _folder.mkdir(exist_ok=True, parents=True)
del _folder


## Forest Files

In [None]:
# List all raster files in the processed data folder
input_raster_files = list_files_by_extension(processed_data_folder, [".tiff", ".tif"])

# Forest cover change ("defostack") raster for the selected forest source,
# excluding derived distance/edge layers. Fail with a clear message rather
# than a bare IndexError when it is missing.
_forest_change_candidates = filter_files_by_keywords(
    input_raster_files,
    ["defostack", forest_source],
    False,
    ["distance", "edge"],
)
if not _forest_change_candidates:
    raise FileNotFoundError(
        f"No 'defostack' raster for forest source '{forest_source}' "
        f"found in {processed_data_folder}."
    )
forest_change_file = _forest_change_candidates[0]

# Yearly forest rasters (no loss/edge layers) and forest-edge rasters.
forest_yearly_files = filter_files_by_keywords(
    input_raster_files, ["forest", forest_source], False, ["loss", "edge"]
)
forest_edge_files = filter_files_by_keywords(
    input_raster_files, ["forest", forest_source, "edge"], False
)


In [None]:
# Raster whose name contains "subj"; raise a clear error instead of IndexError if absent.
_subj_candidates = filter_files_by_keywords(input_raster_files, ["subj"])
if not _subj_candidates:
    raise FileNotFoundError(
        f"No raster with 'subj' in its name found in {processed_data_folder}."
    )
raster_subj_file = _subj_candidates[0]


## Period dictionaries

In [None]:
import re
from pathlib import Path


def create_full_period_dict(
    years: list[int],
    period: str,
    processed_data_folder: Path,
    static_variables: list[str],
    dynamic_variables: list[str],
):
    """
    Create a comprehensive dictionary for a given modeling period.

    The result holds the period metadata (years, deforestation value(s), time
    interval, training period) merged with one entry per variable that maps
    the variable name to the matching raster in `processed_data_folder`, or to
    None when no raster matches. Static variables are searched by name only;
    dynamic (multi-temporal) variables are also matched on the period year(s).

    Parameters
    ----------
    years : list[int]
        At least three years; indices 0, 1 and 2 bound the periods (assumed
        ascending, otherwise `time_interval` is negative).
    period : str
        One of "calibration", "validation", "historical" or "forecast".
    processed_data_folder : Path
        Folder searched for `.tif` / `.tiff` rasters.
    static_variables, dynamic_variables : list[str]
        Variable names; their "_"-separated parts are used as search keywords.

    Returns
    -------
    dict
        Period metadata plus the variable -> raster mapping.

    Raises
    ------
    ValueError
        If `years` has fewer than three elements or `period` is unknown.
    """

    if len(years) < 3:
        raise ValueError("The 'years' list must contain at least three elements.")

    # Per-period settings: indices into `years` for the period bounds and for
    # the year of the multi-temporal variables, the period whose trained
    # outputs are reused, and the forest-change value(s) counted as deforestation.
    configs = {
        "calibration": {
            "train_period": "calibration",
            "initial_idx": 0,
            "final_idx": 1,
            "defor_value": 1,
            "var_idx": 0,
        },
        "validation": {
            "train_period": "calibration",
            "initial_idx": 1,
            "final_idx": 2,
            "defor_value": 1,
            "var_idx": 1,
        },
        "historical": {
            "train_period": "historical",
            "initial_idx": 0,
            "final_idx": 2,
            "defor_value": [1, 2],
            "var_idx": 0,
        },
        "forecast": {
            "train_period": "historical",
            "initial_idx": 0,
            "final_idx": 2,
            "defor_value": [1, 2],
            "var_idx": 2,
        },
    }

    if period not in configs:
        raise ValueError(f"Unknown period '{period}'. Must be one of: {list(configs)}.")

    c = configs[period]

    # --- Base period dictionary ---
    period_dict = {
        "period": period,
        "train_period": c["train_period"],
        "initial_year": years[c["initial_idx"]],
        "final_year": years[c["final_idx"]],
        "defor_value": c["defor_value"],
        "time_interval": years[c["final_idx"]] - years[c["initial_idx"]],
        "var_year": years[c["var_idx"]],
    }

    # Years as strings, used as file-name search keywords.
    initial_year = str(period_dict["initial_year"])
    final_year = str(period_dict["final_year"])
    var_year = str(period_dict["var_year"])

    variable_file_mapping = {"period": period}
    input_raster_files = list_files_by_extension(
        processed_data_folder, [".tiff", ".tif"]
    )

    # --- Modular file search ---
    def _is_token_separate_in_name(token: str, name: str) -> bool:
        """
        Return True if `token` appears in `name` as a separate "word".

        A non-numeric token must not be directly preceded or followed by an
        ASCII letter or digit. Unlike a regex word boundary, "_" counts as a
        separator here, so "rivers" matches inside "dist_rivers". Numeric
        tokens (e.g. years) are matched as plain substrings.
        """
        if token.isdigit():
            return token in name
        pattern = rf"(?<![0-9A-Za-z]){re.escape(token)}(?![0-9A-Za-z])"
        return re.search(pattern, name) is not None

    def _strict_candidate_filter(candidates, tokens):
        """
        Keep only the candidates whose lower-cased path contains every token
        as a separate word (see _is_token_separate_in_name).
        """
        filtered = []
        for p in candidates:
            s = str(p).lower()
            if all(_is_token_separate_in_name(tok.lower(), s) for tok in tokens):
                filtered.append(p)
        return filtered

    def find_file(var_name, dynamic=False):
        """
        Return the raster that best matches `var_name`, or None if none does.

        The "_"-separated parts of `var_name` are the search keywords. Dynamic
        variables must also match the period year(s): both period years for
        "deforestation", `var_year` for the forecast period, and the initial
        year otherwise. For the forecast period, when nothing matches, the
        search is retried with `years[1]` as the year keyword.
        """
        parts = var_name.split("_")
        # Single-word names (e.g. "forest") would otherwise also match the
        # derived distance/edge rasters that contain the same keyword.
        exclude_terms = ["distance", "edge"] if len(parts) == 1 else None

        def _search(year_terms):
            # Build a fresh keyword list every time (never alias `parts`) so
            # the "distance" keyword is applied consistently to every search.
            include_terms = [*parts, *year_terms]
            if "dist" in parts and "distance" not in include_terms:
                include_terms.append("distance")
            return filter_files_by_keywords(
                input_raster_files, include_terms, False, exclude_terms, True
            )

        if not dynamic:
            year_terms = []
        elif period == "forecast":
            year_terms = [var_year]
        elif "deforestation" in parts:
            year_terms = [initial_year, final_year]
        else:
            year_terms = [initial_year]

        files = _search(year_terms)
        if not files and period == "forecast":
            # Fallback when no raster exists for the forecast year.
            files = _search([str(years[1])])
        if not files:
            return None
        if len(files) == 1:
            return files[0]

        # Several candidates: prefer those containing every part as a separate
        # token (first one wins). If none qualifies (e.g. "dist" never stands
        # alone in a name like "distance_rivers"), keep the first candidate
        # instead of silently returning None.
        strict = _strict_candidate_filter(files, parts)
        return strict[0] if strict else files[0]

    # --- Time-independent variables ---
    for var in static_variables:
        variable_file_mapping[var] = find_file(var, dynamic=False)

    # --- Multi-temporal variables ---
    for var in dynamic_variables:
        variable_file_mapping[var] = find_file(var, dynamic=True)

    # --- Final merge ---
    period_dict.update(variable_file_mapping)
    return period_dict


In [None]:
# One dictionary per modelling period (metadata + variable -> raster mapping).
calibration_dict = create_full_period_dict(
    years=years,
    period="calibration",
    processed_data_folder=processed_data_folder,
    static_variables=static_variables,
    dynamic_variables=dynamic_variables,
)
validation_dict = create_full_period_dict(
    years=years,
    period="validation",
    processed_data_folder=processed_data_folder,
    static_variables=static_variables,
    dynamic_variables=dynamic_variables,
)
historical_dict = create_full_period_dict(
    years=years,
    period="historical",
    processed_data_folder=processed_data_folder,
    static_variables=static_variables,
    dynamic_variables=dynamic_variables,
)
forecast_dict = create_full_period_dict(
    years=years,
    period="forecast",
    processed_data_folder=processed_data_folder,
    static_variables=static_variables,
    dynamic_variables=dynamic_variables,
)


## 1 Calculate distance to forest edge

In [None]:
# Passed to rmj.dist_edge_threshold as `defor_threshold` (presumably a percentage).
deforestation_thresh: float = 99.5
# Exclusive upper bound of the distance-to-edge bins (presumably metres — confirm).
max_dist: int = 5000


In [None]:
def calculate_period_dist_edge_threshold(
    forest_change_file,
    period_dictionary,
    deforestation_thresh,
    max_dist,
    model_folder,
    plots_folder,
    dist_step=30,
):
    """
    Compute and save the distance-to-forest-edge threshold for one period.

    Calls `rmj.dist_edge_threshold` with the period's deforestation value(s)
    and forest-edge distance raster, writes the distance table to
    `model_folder/<period>/tab_dist.csv`, the figure to
    `plots_folder/perc_dist_<period>.png`, and the result as a one-row CSV
    `model_folder/<period>/dist_edge_threshold.csv` (its "dist_thresh" column
    is read back by the prediction step).

    Parameters
    ----------
    forest_change_file : path-like
        Forest cover change raster, passed as `fcc_file`.
    period_dictionary : dict
        Period dictionary; reads "period", "defor_value" and "forest_edge".
    deforestation_thresh : float
        Passed as `defor_threshold`.
    max_dist : int
        Exclusive upper bound of the distance bins.
    model_folder, plots_folder : Path
        Output folders for tables and figures.
    dist_step : int, optional
        Width of the distance bins (default 30).

    Returns
    -------
    The value returned by `rmj.dist_edge_threshold` (saved as a one-row
    DataFrame, so presumably a dict of scalars).

    Raises
    ------
    ValueError
        If the period has no forest-edge distance raster.
    """
    # Validate before creating folders or launching the (slow) raster pass.
    dist_file = period_dictionary.get("forest_edge")
    if dist_file is None:
        raise ValueError(
            "No forest-edge distance raster found for period "
            f"'{period_dictionary.get('period')}'; cannot compute the threshold."
        )

    period_name = period_dictionary["period"]
    period_output_folder = model_folder / period_name
    period_output_folder.mkdir(parents=True, exist_ok=True)

    dist_thresh = rmj.dist_edge_threshold(
        fcc_file=forest_change_file,
        defor_values=period_dictionary["defor_value"],
        defor_threshold=deforestation_thresh,
        dist_file=dist_file,
        dist_bins=np.arange(0, max_dist, step=dist_step),
        tab_file_dist=period_output_folder / "tab_dist.csv",
        fig_file_dist=plots_folder / f"perc_dist_{period_name}.png",
        blk_rows=block_rows,
        dist_file_available=True,
        check_fcc=False,
        verbose=False,
    )
    # Save result as a one-row CSV.
    dist_edge_data = pd.DataFrame(dist_thresh, index=[0])
    dist_edge_data.to_csv(
        period_output_folder / "dist_edge_threshold.csv",
        sep=",",
        header=True,
        index=False,
    )
    return dist_thresh


In [None]:
calculate_period_dist_edge_threshold(
    forest_change_file=forest_change_file,
    period_dictionary=calibration_dict,
    deforestation_thresh=deforestation_thresh,
    max_dist=max_dist,
    model_folder=rmj_mw,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_dist_edge_threshold(
    forest_change_file=forest_change_file,
    period_dictionary=historical_dict,
    deforestation_thresh=deforestation_thresh,
    max_dist=max_dist,
    model_folder=rmj_mw,
    plots_folder=plots_folder,
)


## 2 Compute local deforestation rate

In [None]:
def calculate_local_defor_rate(
    forest_change_file, win_size_list, period_dictionary, model_folder
):
    """
    Compute one local deforestation-rate raster per moving-window size.

    For each `win_size` in `win_size_list`, calls `rmj.local_defor_rate` on
    `forest_change_file` with the period's deforestation value(s) and time
    interval, writing `model_folder/<period>/ldefrate_mw_<win_size>.tif`.
    The prediction step later reads these rasters from the folder of the
    period's training period.

    Parameters
    ----------
    forest_change_file : path-like
        Forest cover change raster, passed as `fcc_file`.
    win_size_list : iterable of int
        Window sizes, passed as `win_size`.
    period_dictionary : dict
        Period dictionary; reads "period", "defor_value" and "time_interval".
    model_folder : Path
        Root output folder.
    """
    period_output_folder = model_folder / period_dictionary["period"]
    period_output_folder.mkdir(parents=True, exist_ok=True)

    for win_size in win_size_list:
        rmj.local_defor_rate(
            fcc_file=forest_change_file,
            defor_values=period_dictionary["defor_value"],
            ldefrate_file=period_output_folder / f"ldefrate_mw_{win_size}.tif",
            win_size=win_size,
            time_interval=period_dictionary["time_interval"],
            # NOTE(review): output rescaled to 2..65535; presumably values 0/1
            # are reserved (e.g. for set_defor_cat_zero) — confirm.
            rescale_min_val=2,
            rescale_max_val=65535,
            blk_rows=block_rows,
            verbose=False,
        )


In [None]:
calculate_local_defor_rate(
    forest_change_file=forest_change_file,
    win_size_list=win_size_list,
    period_dictionary=calibration_dict,
    model_folder=rmj_mw,
)


In [None]:
calculate_local_defor_rate(
    forest_change_file=forest_change_file,
    win_size_list=win_size_list,
    period_dictionary=historical_dict,
    model_folder=rmj_mw,
)


## 3 Compute moving-window risk predictions

In [None]:
def get_dist_thresh(ifile):
    """
    Read the distance-to-forest-edge threshold from a CSV file.

    Returns the "dist_thresh" value of the first row of `ifile` (the
    dist_edge_threshold.csv written in the distance-threshold step).
    """
    return pd.read_csv(ifile)["dist_thresh"].loc[0]


def calculate_period_mw_prediction(period_dictionary, win_size_list, model_folder):
    """
    Build the moving-window risk rasters for one period.

    Reuses the training period's outputs (`ldefrate_mw_<win_size>.tif` and the
    distance threshold in `dist_edge_threshold.csv`) and, for each window
    size, calls `rmj.set_defor_cat_zero` with this period's forest-edge
    distance raster, writing `model_folder/<period>/prob_mw_<win_size>_<period>.tif`.

    Parameters
    ----------
    period_dictionary : dict
        Period dictionary; reads "period", "train_period" and "forest_edge".
    win_size_list : iterable of int
        Window sizes whose local deforestation-rate rasters are processed.
    model_folder : Path
        Root output folder.

    Raises
    ------
    ValueError
        If the period has no forest-edge distance raster.
    """
    dist_file = period_dictionary.get("forest_edge")
    if dist_file is None:
        raise ValueError(
            "No forest-edge distance raster found for period "
            f"'{period_dictionary.get('period')}'; cannot build the prediction."
        )

    period_name = period_dictionary["period"]
    period_output_folder = model_folder / period_name
    period_output_folder.mkdir(parents=True, exist_ok=True)

    trained_period_output_folder = model_folder / period_dictionary["train_period"]
    # The threshold does not depend on the window size: read the CSV once.
    dist_thresh = get_dist_thresh(
        trained_period_output_folder / "dist_edge_threshold.csv"
    )

    for win_size in win_size_list:
        rmj.set_defor_cat_zero(
            ldefrate_file=trained_period_output_folder / f"ldefrate_mw_{win_size}.tif",
            dist_file=dist_file,
            dist_thresh=dist_thresh,
            ldefrate_with_zero_file=(
                period_output_folder / f"prob_mw_{win_size}_{period_name}.tif"
            ),
            blk_rows=block_rows,
            verbose=False,
        )


In [None]:
calculate_period_mw_prediction(
    period_dictionary=calibration_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


In [None]:
calculate_period_mw_prediction(
    period_dictionary=historical_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


In [None]:
calculate_period_mw_prediction(
    period_dictionary=validation_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


In [None]:
calculate_period_mw_prediction(
    period_dictionary=forecast_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


## 4 Compute deforestation rate per class

In [None]:
def calculate_period_mw_def_rate(
    forest_change_file, period_dictionary, win_size_list, model_folder
):
    """
    Compute the deforestation rate per risk category for one period.

    For each window size, calls `rmj.defrate_per_cat` on the risk raster
    `model_folder/<period>/prob_mw_<win_size>_<period>.tif` and writes the
    table `model_folder/<period>/defrate_cat_mw_<win_size>_<period>.csv`.

    Parameters
    ----------
    forest_change_file : path-like
        Forest cover change raster, passed as `fcc_file`.
    period_dictionary : dict
        Period dictionary; reads "period" and "time_interval".
    win_size_list : iterable of int
        Window sizes whose risk rasters are processed.
    model_folder : Path
        Root output folder.
    """
    period_name = period_dictionary["period"]
    period_output_folder = model_folder / period_name
    period_output_folder.mkdir(parents=True, exist_ok=True)

    for win_size in win_size_list:
        rmj.defrate_per_cat(
            fcc_file=forest_change_file,
            riskmap_file=period_output_folder / f"prob_mw_{win_size}_{period_name}.tif",
            time_interval=period_dictionary["time_interval"],
            period=period_name,
            tab_file_defrate=(
                period_output_folder / f"defrate_cat_mw_{win_size}_{period_name}.csv"
            ),
            blk_rows=block_rows,
            verbose=False,
        )


In [None]:
calculate_period_mw_def_rate(
    forest_change_file=forest_change_file,
    period_dictionary=calibration_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


In [None]:
calculate_period_mw_def_rate(
    forest_change_file=forest_change_file,
    period_dictionary=historical_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


In [None]:
calculate_period_mw_def_rate(
    forest_change_file=forest_change_file,
    period_dictionary=validation_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


In [None]:
calculate_period_mw_def_rate(
    forest_change_file=forest_change_file,
    period_dictionary=forecast_dict,
    win_size_list=win_size_list,
    model_folder=rmj_mw,
)


## 5 Plot deforestation risk maps

In [None]:
input_vector_files = list_files_by_extension(processed_data_folder, [".shp"])
# Area-of-interest shapefile, used as map borders; raise a clear error
# instead of a bare IndexError when it is missing.
_aoi_candidates = filter_files_by_keywords(input_vector_files, ["aoi"])
if not _aoi_candidates:
    raise FileNotFoundError(
        f"No shapefile with 'aoi' in its name found in {processed_data_folder}."
    )
aoi_vector = _aoi_candidates[0]


In [None]:
def calculate_period_plot_risk_map(
    period_dictionary, aoi_vector, model_folder, plots_folder
):
    """
    Render a PNG risk map for each moving-window risk raster of a period.

    Every raster in `model_folder/<period>` matching the "prob" and "mw"
    keywords (excluding "ldefrate") is drawn with `rmj.plot.riskmap`, using
    `aoi_vector` as borders, and saved to `plots_folder/<raster base name>.png`.

    NOTE(review): the figure returned by `rmj.plot.riskmap` is discarded; if it
    is a matplotlib figure, consider closing it to free memory in long loops.
    """
    source_folder = model_folder / period_dictionary["period"]
    candidates = filter_files_by_keywords(
        list_files_by_extension(source_folder, [".tiff", ".tif"]),
        ["prob", "mw"],
        False,
        ["ldefrate"],
    )
    for raster_path in candidates:
        stem, _extension = os.path.splitext(os.path.basename(raster_path))
        target_png = os.path.join(plots_folder, f"{stem}.png")
        rmj.plot.riskmap(
            input_risk_map=str(raster_path),
            maxpixels=1e8,
            output_file=target_png,
            borders=aoi_vector,
            legend=False,
            figsize=(6, 5),
            dpi=100,
            linewidth=0.3,
        )


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=calibration_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_mw,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=historical_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_mw,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=validation_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_mw,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=forecast_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_mw,
    plots_folder=plots_folder,
)
