In [None]:
# top functions
def split_dataset_and_save(ds_fp, divisor, output_dir_name: str=None, select_vars: list[str]=None):
    """Split a NetCDF dataset into tiles and write every tile to its own file.

    The file at ``ds_fp`` is opened, optionally restricted to ``select_vars``,
    tiled with ``split_dataset_by_indices`` and each tile is saved as a ``.nc``
    file inside a directory created next to the source file.

    Args:
        ds_fp: path of the NetCDF file to split.
        divisor: number of tiles along each of latitude and longitude.
        output_dir_name: optional prefix for the name of the output directory.
        select_vars: optional list of variable names to keep before splitting.

    Returns:
        dict mapping each tile's coordinate-bounds string to the tile dataset.
    """
    dataset = xa.open_dataset(ds_fp)
    if select_vars:
        dataset = dataset[select_vars]

    tiles = split_dataset_by_indices(dataset, divisor)

    # the output directory sits beside the source file; its name records the tile count
    source_path = Path(ds_fp)
    dir_name = f"{divisor**2}_split_datasets"
    if output_dir_name:
        dir_name = f"{output_dir_name}_{dir_name}"
    target_dir = source_path.parent / dir_name
    target_dir.mkdir(parents=True, exist_ok=True)

    # whatever precedes "lats" in the source stem is reused as the filename prefix
    prefix = str(source_path.stem).split("lats")[0]
    for bounds_label, tile in tqdm(tiles.items(), desc="saving dataset subsets..."):
        tile.to_netcdf(target_dir / f"{prefix}_{bounds_label}.nc")
    return tiles
    

def split_dataset_by_indices(dataset, divisor) -> dict:
    """Tile a dataset into ``divisor x divisor`` latitude/longitude blocks.

    Tiles are cut by integer position along ``latitude`` and ``longitude``.
    When a dimension's length is not a multiple of ``divisor`` the final tile
    along that dimension takes the leftover rows/columns, so every grid cell
    ends up in exactly one tile.

    Args:
        dataset: object exposing ``latitude``/``longitude`` coordinates and
            ``isel`` (an xarray Dataset in this notebook).
        divisor: number of tiles along each dimension.

    Returns:
        dict mapping a bounds string (built by
        ``functions_creche.tuples_to_string`` from the tile's latitude and
        longitude min/max) to the tile.

    Raises:
        ValueError: if ``divisor`` is smaller than 1 or larger than either
            dimension, which would produce empty tiles.
    """
    n_lats = len(dataset.latitude.values)
    n_lons = len(dataset.longitude.values)
    if divisor < 1 or divisor > min(n_lats, n_lons):
        raise ValueError(
            f"divisor must be between 1 and {min(n_lats, n_lons)}, got {divisor}"
        )

    lat_step = n_lats // divisor
    lon_step = n_lons // divisor
    last = divisor - 1

    subsets_dict = {}
    for i in range(divisor):
        lat_start = i * lat_step
        # an open-ended slice on the last tile picks up the remainder
        lat_stop = None if i == last else lat_start + lat_step
        for j in range(divisor):
            lon_start = j * lon_step
            lon_stop = None if j == last else lon_start + lon_step

            subset = dataset.isel(latitude=slice(lat_start, lat_stop),
                                  longitude=slice(lon_start, lon_stop))

            lat_lims = spatial_data.min_max_of_coords(subset, "latitude")
            lon_lims = spatial_data.min_max_of_coords(subset, "longitude")

            coord_info = functions_creche.tuples_to_string(lat_lims, lon_lims)
            subsets_dict[coord_info] = subset

    return subsets_dict


def ds_to_ml_ready(ds, 
    gt:str="unep_coral_presence", exclude_list: list[str]=["latitude", "longitude", "latitude_grid", "longitude_grid", "crs", "depth", "spatial_ref"], 
    train_val_test_frac=[1,0,0], inf_type: str="classification", threshold=0.5, depth_mask_lims = [-50, 0], client=None, remove_rows:bool=False):
    """Flatten a dataset into a predictor frame ``X`` and a target series ``y``.

    The dataset is computed and converted to a DataFrame. Predictors are all
    columns other than ``gt`` and those named in ``exclude_list``. Predictor
    columns are min-max scaled; the target column is returned with its
    original values so that labels are not rescaled (a constant target would
    otherwise be mapped to 0 by the scaler). Remaining NaNs are filled with 0.

    Args:
        ds: dataset exposing ``compute().to_dataframe()``; must contain an
            ``elevation`` variable.
        gt: name of the target column.
        exclude_list: column names that are never used as predictors.
        train_val_test_frac: currently unused.
        inf_type: currently unused.
        threshold: currently unused.
        depth_mask_lims: two elevation limits; rows strictly between the
            smaller and the larger one count as within depth.
        client: currently unused.
        remove_rows: if True, rows outside the depth limits are dropped; if
            False they are kept and a ``within_depth`` flag column is added.

    Returns:
        ``(X, y)``: DataFrame of predictors and Series of the target.

    NOTE(review): ``within_depth`` and ``nan_onehot`` are added after the
    predictor list is fixed, so neither column is part of ``X`` - confirm
    whether they were meant to be features.
    """
    df = ds.compute().to_dataframe()
    # TODO: implement checking for empty dfs

    predictors = [col for col in df.columns if col != gt and col not in exclude_list]
    depth_condition = (df["elevation"] < max(depth_mask_lims)) & (df["elevation"] > min(depth_mask_lims))

    if remove_rows:
        # copy so the column assignments below do not write into a slice
        df = df.loc[depth_condition].copy()
    else:
        df["within_depth"] = depth_condition.astype(int)

    if len(df) > 0 and predictors:
        # scale predictors only: the target and excluded columns stay untouched
        df[predictors] = MinMaxScaler().fit_transform(df[predictors])

    df["nan_onehot"] = df.isna().any(axis=1).astype(int)
    df = df.fillna(0)

    return df[predictors], df[gt]

def cont_to_class(array, threshold=0.5):
    """Binarise continuous values: 1 where ``array >= threshold``, else 0.

    The input is left untouched; a new integer array (or Series, if a Series
    was passed) is returned. NaNs compare as False and therefore map to 0.

    Args:
        array: numpy array or pandas Series of numeric values.
        threshold: values greater than or equal to this become 1.

    Returns:
        Integer array/Series of 0s and 1s with the same shape as ``array``.
    """
    # a single comparison avoids mutating the caller's data and stays correct
    # for thresholds above 1, where two sequential in-place writes would
    # zero out the values that had just been set to 1
    return (array >= threshold).astype(int)


def customize_plot_colors(fig, ax, background_color="#212121", text_color="white"):
    """Recolour a figure and one of its axes.

    Sets the figure and axes face colour to ``background_color`` and the colour
    of figure texts, axes texts, tick labels, title, axis labels and legend
    entries (if a legend exists) to ``text_color``.

    Returns:
        The same ``(fig, ax)`` pair that was passed in.
    """
    # backgrounds
    fig.patch.set_facecolor(background_color)
    ax.set_facecolor(background_color)

    # collections of text artists that currently exist on the figure/axes
    text_groups = (
        fig.texts,
        ax.texts,
        ax.xaxis.get_ticklabels(),
        ax.yaxis.get_ticklabels(),
    )
    for group in text_groups:
        for artist in group:
            artist.set_color(text_color)

    # single text artists: title and both axis labels
    for artist in (ax.title, ax.xaxis.label, ax.yaxis.label):
        artist.set_color(text_color)

    # legend entries, only when the axes has a legend
    legend = ax.get_legend()
    if legend:
        for entry in legend.get_texts():
            entry.set_color(text_color)

    return fig, ax

In [None]:
# essentials
high_res_ds = xa.open_dataset("data/temp_rf_lats_-32-0_lons_130-170_ds.nc")

# NOTE: with num_files = -1 the slice below keeps every tile except the last one
num_files = -1
# sorted: glob order depends on the filesystem, so without it the selected
# tiles could differ between machines/runs
train_nc_fps = sorted(Path("data/64_split_datasets").glob("temp_rf*.nc"))[:num_files]
# test_nc_fps = list(Path("data/64_split_datasets").glob("temp_rf*.nc"))[num_files:num_files + num_files]


train_Xs, train_ys = [], []
for fp in tqdm(train_nc_fps):
    # ds_to_ml_ready computes the tile into a DataFrame, so the file handle
    # can be released as soon as it returns
    with xa.open_dataset(fp) as ds:
        X, y = ds_to_ml_ready(ds, remove_rows=True)
    train_Xs.append(X)
    train_ys.append(y)

X_df = pd.concat(train_Xs, axis=0)
y_df = pd.concat(train_ys, axis=0)

# test_Xs, test_ys = [], []
# for i, fp in enumerate(test_nc_fps):
#     ds = xa.open_dataset(fp)
#     X, y = ds_to_ml_ready(ds, remove_rows=True)
#     test_Xs.append(X)
#     test_ys.append(y)

X = X_df.to_numpy()
y = y_df.to_numpy()

print("X shape:", X.shape)
print("y shape:", y.shape)

In [None]:
# client.close()

# local dask cluster
# NOTE: the client is deliberately left open - the next cell still fits
# through the dask joblib backend.
cluster = LocalCluster(n_workers=4)
client = Client(cluster)

model_type="classification"
n_samples = 10000
threshold = 0.25

# NOTE(review): X_train, X_test, y_train and y_test are not created by any
# earlier visible cell; the cell calling generate_split has to run first.

# binarise the continuous ground truth with the configured threshold
y_train = cont_to_class(y_train, threshold)

# minlength=2 keeps index 1 valid even when no positive sample is present
class_counts = np.bincount(y_train, minlength=2)
total_samples = len(y_train)

# negatives per positive; falls back to the negative count when there are no positives
balanced_pos_weight = float(class_counts[0]) / max(int(class_counts[1]), 1)

# Create the XGBoost DMatrices
dX_train = dask.array.from_array(X_train)
dy_train = dask.array.from_array(y_train).rechunk(dX_train.chunksize[0])

dX_test = dask.array.from_array(X_test)
dy_test = dask.array.from_array(y_test).rechunk(dX_test.chunksize[0])


if model_type == "classification":
    y_test = cont_to_class(y_test, threshold=threshold)

    model = xgb.XGBClassifier(scale_pos_weight=balanced_pos_weight)

elif model_type == "regression":
    # learning_rate rather than its alias eta, so it cannot clash with the
    # learning_rate entry of the search space below
    model = xgb.XGBRegressor(
        n_estimators=1000, max_depth=7, learning_rate=0.1, subsample=0.7, colsample_bytree=0.8, scale_pos_weight=balanced_pos_weight
    )


# dtrain = xgb.dask.DaskDMatrix(client, dX_train, dy_train)
# dtest = xgb.dask.DaskDMatrix(client, dX_test,  dy_test)

# Search space restricted to parameters XGBoost actually consumes. The former
# keys (bootstrap, ccp_alpha, class_weight, criterion, max_features, ...) are
# scikit-learn forest options, so sampling them did not change the fitted model.
# scale_pos_weight replaces class_weight: the balanced ratio or five times it,
# and being part of the space it is carried into search.best_params_.
param_space = {
    'max_depth': [3, 5, 7, 10, 50, 100],
    'n_estimators': [50, 100, 200],
    'learning_rate': [0.01, 0.1, 0.3],
    'subsample': [0.5, 0.7, 0.9, 1.0],
    'colsample_bytree': [0.5, 0.8, 1.0],
    'min_child_weight': [1, 2, 4],
    'gamma': [0.0, 0.1, 0.2],
    'scale_pos_weight': [balanced_pos_weight, 5 * balanced_pos_weight],
    'random_state': [42],  # Add different seed values if desired
}

search = RandomizedSearchCV(model, param_space, cv=3, n_iter=100, verbose=2, n_jobs=-1)

# already saved to pickle
with joblib.parallel_backend('dask'):
    search.fit(X_train[:n_samples], y_train[:n_samples])

with open("xgb_cl_best_rerun.pickle", "wb") as handle:
    pickle.dump(search.best_params_, handle, protocol=-1)

In [None]:
# Rebuild the classifier from the best hyper-parameters. best_params_ only
# holds the sampled keys, so the class weighting configured on the searched
# estimator is carried over explicitly (a tuned value in best_params_ wins).
base_pos_weight = search.estimator.get_params().get("scale_pos_weight")
model = xgb.XGBClassifier(**{"scale_pos_weight": base_pos_weight, **search.best_params_})

with joblib.parallel_backend('dask'):
    model.fit(X_train[:n_samples], y_train[:n_samples])

# y_test_pred = model.predict_proba(X_test) # 2xN
y_test_pred = model.predict(X_test)

In [None]:
# formatting confusion matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

fig, ax = plt.subplots(dpi=300)
cmap = spatial_plots.get_cbar()

# matrix normalised over the true labels (each row sums to 1)
cm = confusion_matrix(
    y_test,
    cont_to_class(y_test_pred, 0.5),
    labels=[0, 1],
    normalize="true",
)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=["absence", "presence"],
)

# recolour first, then draw the matrix with black cell text and no colourbar
customize_plot_colors(fig,ax)
disp.plot(ax=ax, cmap=cmap, text_kw={"c":"k"}, colorbar=False)

# the display labels are self-explanatory, so the axis titles are blanked
ax.set_xlabel('')
ax.set_ylabel('')
plt.yticks(rotation=90)

# summary metrics of the hard predictions against the test labels
balanced_accuracy_score = sklmetrics.balanced_accuracy_score(y_test, y_test_pred)
accuracy_score = sklmetrics.accuracy_score(y_test, y_test_pred)
f1_score = sklmetrics.f1_score(y_test, y_test_pred)

for metric_name, metric_value in (
    ("balanced_accuracy_score", balanced_accuracy_score),
    ("accuracy_score", accuracy_score),
    ("f1_score", f1_score),
):
    print(metric_name, metric_value)


In [None]:
# multi model run functions
# resolutions to iterate over; 1, 0.5 and 0.1 are currently left out
resolutions = [0.25, 0.05]

def get_min_max_coords(ds, coord):
    """Return ``(min, max)`` of coordinate ``coord`` in ``ds`` as plain floats.

    Uses the array's own vectorised ``min``/``max`` reductions instead of the
    Python builtins, which would walk the coordinate element by element.

    Args:
        ds: mapping-like dataset; ``ds[coord]`` must expose ``min()``/``max()``.
        coord: name of the coordinate to inspect.

    Returns:
        Tuple ``(min_coord, max_coord)`` of floats.
    """
    values = ds[coord]
    return float(values.min()), float(values.max())


def process_df_for_ml(
    df: pd.DataFrame, ignore_vars: list[str], drop_all_nans: bool = True
) -> pd.DataFrame:
    """Prepare a flattened frame for model training.

    Steps: drop the columns named in ``ignore_vars`` that are present,
    optionally pass the frame through ``utils.drop_nan_rows``, add an
    ``onehotnan`` flag (1 for rows that still contain a NaN) and fill the
    remaining NaNs with 0.

    Args:
        df: input frame; it is not modified.
        ignore_vars: column names to remove when they exist in ``df``.
        drop_all_nans: whether to apply ``utils.drop_nan_rows`` (a helper
            defined elsewhere; presumably removes rows that are entirely NaN).

    Returns:
        The processed frame.
    """
    # only names that are actually columns can be dropped
    present_ignored = list(set(ignore_vars) & set(df.columns))
    cleaned = df.drop(columns=present_ignored)

    if drop_all_nans:
        cleaned = utils.drop_nan_rows(cleaned)

    # flag rows with leftover NaNs before those NaNs are replaced by 0
    cleaned["onehotnan"] = cleaned.isna().any(axis=1).astype(int)
    return cleaned.fillna(0)



def xa_dss_to_df(
    xa_dss: list[xa.Dataset],
    bath_mask: bool = True,
    res: float = 1,
    ignore_vars: list = ["spatial_ref", "band", "depth"],
    drop_all_nans: bool = True,
):
    """Convert each dataset in ``xa_dss`` into a processed, flattened DataFrame.

    For every dataset: optionally blank out (set to NaN) everything outside
    the elevation window configured for ``res``, stack latitude/longitude into
    a single ``points`` dimension, compute, cast to float32, convert to a
    DataFrame, drop datetime columns and run ``process_df_for_ml``.

    Args:
        xa_dss: datasets to convert.
        bath_mask: whether to apply the elevation mask.
        res: resolution used to look up the elevation limits; must be one of
            the keys of the table below when ``bath_mask`` is True.
        ignore_vars: column names forwarded to ``process_df_for_ml``.
        drop_all_nans: forwarded to ``process_df_for_ml``.

    Returns:
        List with one processed DataFrame per input dataset.
    """
    # elevation limits [lower, upper] per resolution
    elevation_limits_by_res = {
        1: [-2000,0],
        0.5: [-1000,0],
        0.25: [-500,0],
        0.1: [-80,0],
        0.05: [-60,0],
        0.01: [-60,0],
    }

    frames = []
    for dataset in xa_dss:
        if bath_mask:
            limits = elevation_limits_by_res[res]
            print(limits)
            # values outside the elevation window become NaN for later omission
            keep_mask = spatial_data.generate_var_mask(
                dataset, mask_var="elevation", limits=limits
            )
            dataset = dataset.where(keep_mask, np.nan)

        # one row per (latitude, longitude) point, all values as float32
        stacked = dataset.stack(points=("latitude", "longitude"))
        table = stacked.compute().astype("float32").to_dataframe()

        # temporal columns are not used downstream
        datetime_cols = list(table.select_dtypes(include="datetime64").columns)
        table = table.drop(columns=datetime_cols)

        frames.append(
            process_df_for_ml(table, ignore_vars=ignore_vars, drop_all_nans=drop_all_nans)
        )
    return frames



def generate_split(ds, res):
    """Flatten ``ds``, min-max normalise it and return an 80/20 train/test split.

    Args:
        ds: dataset passed to ``xa_dss_to_df`` (with the elevation mask on).
        res: resolution, forwarded to ``xa_dss_to_df``.

    Returns:
        ``X_train, X_test, y_train, y_test`` as numpy arrays, where ``y`` is
        the normalised ``unep_coral_presence`` column and ``X`` holds every
        other column.

    NOTE(review): the min/max are taken over the full frame before splitting,
    so test rows influence the scaling of the training rows - confirm this is
    acceptable.
    """
    # flatten the dataset to a pandas dataframe and process
    flattened = xa_dss_to_df([ds], bath_mask=True, res=res)[0]

    # min/max scaling; a constant column has a zero range, which would turn
    # the whole column into NaN (0/0), so such ranges are replaced by 1 and
    # the column becomes all zeros instead
    col_min = flattened.min()
    col_range = flattened.max() - col_min
    col_range = col_range.where(col_range != 0, 1)
    normalised = (flattened - col_min) / col_range

    y = normalised["unep_coral_presence"].to_numpy()
    X = normalised.drop(columns="unep_coral_presence").to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    return X_train, X_test, y_train, y_test


def train_model(model_type: str, X_train, y_train, cv=3, n_iter=50, n_samples=-1, threshold=0.25):
    """Run a randomised hyper-parameter search for an XGBoost model.

    A local dask cluster is started for the duration of the search and shut
    down afterwards, also when the search raises.

    Args:
        model_type: ``"classification"``/``"classifier"`` or ``"regression"``.
        X_train: training features.
        y_train: training targets; binarised with ``threshold`` for
            classification (labels that are already 0/1 are unaffected for
            thresholds in (0, 1]).
        cv: number of cross-validation folds.
        n_iter: number of sampled parameter settings.
        n_samples: number of leading samples used for the search; ``None`` or
            a negative value uses all samples.
        threshold: binarisation threshold for classification targets.

    Returns:
        ``(search.best_params_, search)``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported strings.
    """
    # Search space restricted to parameters XGBoost actually consumes; the
    # earlier scikit-learn forest options (bootstrap, criterion, class_weight,
    # ...) did not change the fitted model.
    param_space = {
        'max_depth': [3, 5, 7, 10, 50, 100],
        'n_estimators': [50, 100, 200],
        'learning_rate': [0.01, 0.1, 0.3],
        'subsample': [0.5, 0.7, 0.9, 1.0],
        'colsample_bytree': [0.5, 0.8, 1.0],
        'min_child_weight': [1, 2, 4],
        'gamma': [0.0, 0.1, 0.2],
        'random_state': [42],  # Add different seed values if desired
    }

    if model_type in ("classification", "classifier"):
        y_train = cont_to_class(y_train, threshold=threshold)

        n_pos = int(np.sum(y_train == 1))
        n_neg = int(np.sum(y_train == 0))
        # negatives per positive; neutral weight when there are no positives
        balanced_pos_weight = n_neg / n_pos if n_pos else 1.0

        model = xgb.XGBClassifier(scale_pos_weight=balanced_pos_weight)
        # Stands in for the former class_weight choice ('balanced' or positives
        # weighted five times more). Being part of the space, the chosen value
        # is returned in best_params_ and reaches models rebuilt from it.
        param_space['scale_pos_weight'] = [balanced_pos_weight, 5 * balanced_pos_weight]

    elif model_type == "regression":
        model = xgb.XGBRegressor(
            n_estimators=1000, max_depth=7, learning_rate=0.1, subsample=0.7, colsample_bytree=0.8
        )

    else:
        raise ValueError(
            f"model_type must be 'classification', 'classifier' or 'regression', got {model_type!r}"
        )

    # a negative n_samples used as a slice end would silently drop samples
    if n_samples is None or n_samples < 0:
        X_fit, y_fit = X_train, y_train
    else:
        X_fit, y_fit = X_train[:n_samples], y_train[:n_samples]

    search = RandomizedSearchCV(model, param_space, cv=cv, n_iter=n_iter, verbose=2, n_jobs=-1)

    # context managers close both client and cluster even if the fit fails
    with LocalCluster(n_workers=4) as cluster, Client(cluster):
        with joblib.parallel_backend('dask'):
            search.fit(X_fit, y_fit)

    return search.best_params_, search



def models_at_resolutions(high_res_ds, resolutions:list, model_type, cv, n_iter, n_samples, threshold=0.25):
    """Tune, fit and score one XGBoost classifier per spatial resolution.

    For each resolution the dataset is resampled (except at 0.01, where
    ``high_res_ds`` is used as is), split with ``generate_split``, tuned with
    ``train_model`` on at most ``n_samples`` training samples, refitted on the
    whole training split and scored on the test split. The best parameters
    are also pickled to ``xgb_cl_standard_<res>.pickle`` in the working
    directory.

    Only ``model_type == "classifier"`` is handled; for any other value the
    splits are generated but nothing is trained and the returned containers
    stay empty.

    Args:
        high_res_ds: dataset to resample.
        resolutions: resolutions to iterate over.
        model_type: see above.
        cv: cross-validation folds for the search.
        n_iter: sampled parameter settings for the search.
        n_samples: maximum number of training samples used for the search.
        threshold: binarisation threshold applied to the targets.

    Returns:
        ``(best_param_dicts, models, accuracy_scores, f1_scores,
        balanced_accuracy_scores)``; the dict is keyed by resolution and the
        lists follow the order of ``resolutions``.
    """
    models = []
    best_param_dicts = {}
    balanced_accuracy_scores = []
    accuracy_scores = []
    f1_scores = []

    # the coordinate extent does not depend on the target resolution
    lat_range = get_min_max_coords(high_res_ds, "latitude")
    lon_range = get_min_max_coords(high_res_ds, "longitude")

    for res in tqdm(resolutions, desc="iterating over resolutions..."):
        if res == 0.01:
            print(f"bypassing resampling since already at {res}")
            ds_res = high_res_ds
        else:
            ds_res = functions_creche.resample_xa_d(high_res_ds, lat_range, lon_range, res, res)

        X_train, X_test, y_train, y_test = generate_split(ds_res, res)

        if model_type == "classifier":
            y_train = cont_to_class(y_train, threshold=threshold)
            y_test = cont_to_class(y_test, threshold=threshold)

            # clamp per resolution: a small coarse grid must not shrink the
            # sample budget of the finer resolutions that follow
            res_n_samples = n_samples
            if res_n_samples > len(X_train):
                print("reassigned n_samples to len(X_train):", len(X_train))
                res_n_samples = len(X_train)

            best_params, search = train_model("classifier", X_train, y_train, cv, n_iter, res_n_samples)
            best_param_dicts[res] = best_params

            model = xgb.XGBClassifier(**best_params)
            model.fit(X_train, y_train)
            models.append(model)

            # y_test is already binarised with `threshold` above
            y_test_pred = cont_to_class(model.predict(X_test), 0.5)
            balanced_accuracy_score = sklmetrics.balanced_accuracy_score(y_test, y_test_pred)
            accuracy_score = sklmetrics.accuracy_score(y_test, y_test_pred)
            f1_score = sklmetrics.f1_score(y_test, y_test_pred)

            with open(f"xgb_cl_standard_{res:.03f}.pickle", "wb") as handle:
                pickle.dump(best_params, handle, protocol=-1)

            balanced_accuracy_scores.append(balanced_accuracy_score)
            accuracy_scores.append(accuracy_score)
            f1_scores.append(f1_score)
            print(f"resolution: {res}°")
            print(f"balanced accuracy_score: {balanced_accuracy_score}")
            print(f"accuracy_score: {accuracy_score}")
            print(f"f1_score: {f1_score}")

    return best_param_dicts, models, accuracy_scores, f1_scores, balanced_accuracy_scores

# high_res_ds = xa.open_dataset("data/temp_rf_lats_-32-0_lons_130-170_ds.nc")
# best_param_dicts, models, accuracy_scores, f1_scores, balanced_accuracy_scores = models_at_resolutions(high_res_ds, resolutions, model_type="classifier", cv=5, n_iter=50, n_samples=10000)

In [None]:
# single resolution run
# TODO: could split off train ds with n_samples earlier to ease processing
res = 0.01
X_train, X_test, y_train, y_test = generate_split(high_res_ds, res)
threshold = 0.25
cv = 5
n_iter = 50
n_samples = 10000
model_type = "classifier"

# result containers; created here so the cell runs on a fresh kernel (they
# were otherwise only defined inside models_at_resolutions)
best_params_dicts = {}
models = []
balanced_accuracy_scores = []
accuracy_scores = []
f1_scores = []

if model_type == "classifier":
    # binarise the full training targets: the final model is fitted on all of
    # X_train, so y_train has to keep the same length
    y_train = cont_to_class(y_train, threshold=threshold)
    y_test = cont_to_class(y_test, threshold=threshold)

    if n_samples > len(X_train):
        print("reassigned n_samples to len(X_train):", len(X_train))
        n_samples = len(X_train)

    # the hyper-parameter search only sees the first n_samples rows
    best_params, search = train_model("classifier", X_train[:n_samples], y_train[:n_samples], cv, n_iter, n_samples)
    best_params_dicts[res] = best_params

    model = xgb.XGBClassifier(**best_params)
    model.fit(X_train, y_train)
    models.append(model)

    # y_test is already binarised with `threshold` above
    y_test_pred = cont_to_class(model.predict(X_test), 0.5)
    balanced_accuracy_score = sklmetrics.balanced_accuracy_score(y_test, y_test_pred)
    accuracy_score = sklmetrics.accuracy_score(y_test, y_test_pred)
    f1_score = sklmetrics.f1_score(y_test, y_test_pred)

    with open(f"xgb_cl_standard_{res:.03f}.pickle", "wb") as handle:
        pickle.dump(best_params, handle, protocol=-1)

    balanced_accuracy_scores.append(balanced_accuracy_score)
    accuracy_scores.append(accuracy_score)
    f1_scores.append(f1_score)
    print(f"resolution: {res}°")
    print(f"balanced accuracy_score: {balanced_accuracy_score}")
    print(f"accuracy_score: {accuracy_score}")
    print(f"f1_score: {f1_score}")
