<a href="https://colab.research.google.com/github/tnc-br/ddf-isoscapes/blob/fb/vi_minimal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational model

Estimate the mean and variance of O18 isotope ratios (and, in the future, N15 and C13) at any given latitude/longitude across Brazil.

# Import required libraries

In [79]:
import importlib
from datetime import datetime
import sys
import os

!if [ ! -d "/content/ddf_common_stub" ] ; then git clone -b test https://github.com/tnc-br/ddf_common_stub.git; fi
sys.path.append("/content/ddf_common_stub/")
import ddfimport
ddfimport.ddf_source_control_pane()

interactive(children=(Text(value='', description='Email', placeholder='Enter email'), Text(value='', descripti…

In [88]:
import train_variational_inference_model as tvim
import raster
import model
import importlib
import evaluation
import eeddf
import bqddf

# Re-execute the project modules so edits made since the first import take
# effect without restarting the runtime.
# The helper modules are reloaded first and tvim last: importlib.reload() does
# not rebind names that another module obtained via `from ... import ...`, so
# reloading tvim before the others could leave it holding pre-reload objects.
# (Presumably tvim depends on these helpers — verify against its imports.)
importlib.reload(raster)
importlib.reload(model)
importlib.reload(evaluation)
importlib.reload(bqddf)
importlib.reload(tvim)

# NOTE(review): eeddf is imported but never reloaded — confirm that is intended.
eeddf.initialize_ddf(test_environment=True)

# Model configuration

In [83]:
# Training configuration handed to tvim.VIModelTrainingParams.
# - training_id also names the saved artifacts: the model is written to
#   "<training_id>.keras" and the isoscape raster to "<training_id>.tiff".
# - mean_label / var_label name the d18O mean and variance columns.
# - features_to_standardize lists the input columns to standardize;
#   features_to_passthrough is empty for this run.
# Each argument stays on its own line with its `#@param` marker; keep that
# layout when editing values.
params = tvim.VIModelTrainingParams(
    training_id = "test-2024-08-16", #@param
    num_epochs = 5000, #@param
    num_layers = 2, #@param
    num_nodes_per_layer = 20, #@param
    training_batch_size = 5, #@param
    learning_rate = 0.0001, #@param
    mean_label = "d18O_cel_mean", #@param
    var_label = "d18O_cel_variance", #@param
    early_stopping_patience = 100, #@param
    double_sided_kl = False, #@param
    kl_num_samples_from_pred_dist = 15, #@param
    dropout_rate = 0, #@param
    activation_func = "relu", #@param
    features_to_standardize = ['lat', 'long', 'VPD', 'RH', 'PET', 'DEM', 'PA', 'Mean Annual Temperature', 'Mean Annual Precipitation', 'Iso_Oxi_Stack_mean_TERZER', 'isoscape_fullmodel_d18O_prec_REGRESSION', 'brisoscape_mean_ISORIX', 'd13C_cel_mean', 'd13C_cel_var', 'ordinary_kriging_linear_d18O_predicted_mean', 'ordinary_kriging_linear_d18O_predicted_variance'], #@param
    features_to_passthrough = [], #@param
    resolution_x = 1024, #@param
    resolution_y = 1024, #@param
)



In [None]:
# Set to True if a training run with params.training_id has already been done
# and you want to reuse the isoscape produced by that run.
# The value is forwarded as `eval_only` to tvim.train_variational_inference_model.
EVAL_ONLY = True #@param {type:"boolean"}

In [85]:
# Evaluation configuration handed to tvim.VIModelEvalParams.
# elements_to_eval restricts evaluation to the listed elements (only d18O_cel
# here); precision_target / recall_target are both set to 0.95.
# NOTE(review): the recorded output of the training cell reports
# 'radius_pace': <class 'int'> and then fails with
# "TypeError: Object of type type is not JSON serializable", even though 100 is
# passed below. Presumably the params class (or its conversion to a dict) drops
# this value — confirm in train_variational_inference_model.
eval_params = tvim.VIModelEvalParams(
    samples_per_location = 5, #@param
    precision_target = 0.95, #@param
    recall_target = 0.95, #@param
    start_max_fraud_radius= 6, #@param
    end_max_fraud_radius = 3000, #@param
    radius_pace = 100, #@param
    max_fraud_dist = 3000, #@param
    min_trusted_dist = 5, #@param
    elements_to_eval = ['d18O_cel'], #@param
)

# Data configuration

In [86]:
from google.colab import drive
drive.mount(raster.GDRIVE_BASE)


# TRAINING FILE PARAMS
# All sample CSVs are resolved relative to this directory on the mounted drive.
DATABASE_DIR = raster.GDRIVE_BASE + raster.SAMPLE_DATA_BASE
TRAINING_SET_FILE = 'demo_train_fixed_grouped.csv' #@param
VALIDATION_SET_FILE = 'demo_validation_fixed_grouped.csv' #@param
TEST_SET_FILE = 'demo_test_fixed_grouped.csv' #@param

# EVAL FILE PARAMS
EVAL_DATASET = 'demo_test_fixed_grouped.csv' #@param
ORIGINAL_DATASET = '2023_06_23_Results_Google.csv' #@param

# Maps each dataset role to its CSV path. 'TEST' and 'VALIDATION' must point at
# their own files: they were previously swapped, which made the validation
# split identical to the 'EVAL' file.
fileset = {
    'TRAIN' : os.path.join(DATABASE_DIR, TRAINING_SET_FILE),
    'TEST' : os.path.join(DATABASE_DIR, TEST_SET_FILE),
    'VALIDATION' : os.path.join(DATABASE_DIR, VALIDATION_SET_FILE),
    'EVAL' : os.path.join(DATABASE_DIR, EVAL_DATASET),
    'ORIGINAL' : os.path.join(DATABASE_DIR, ORIGINAL_DATASET)
}



# Output artifacts are named after the training run id.
MODEL_SAVE_LOCATION = os.path.join(raster.GDRIVE_BASE, raster.MODEL_BASE, params.training_id + ".keras")
ISOSCAPE_SAVE_LOCATION = raster.get_raster_path(params.training_id+".tiff")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


# Train the model

In [87]:
# Run the variational-inference pipeline with the training/evaluation params,
# the dataset paths, and the isoscape/model save locations configured in the
# cells above. EVAL_ONLY is forwarded as the eval_only flag.
res = tvim.train_variational_inference_model(
    params,
    eval_params,
    fileset,
    ISOSCAPE_SAVE_LOCATION,
    MODEL_SAVE_LOCATION,
    eval_only=EVAL_ONLY,
)

Driver: GTiff/GeoTIFF
Size is 1024 x 1024 x 2
Projection is GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]
Origin = (-74.0000000000241, 5.29166666665704)
Pixel Size = (0.03828938802082461, -0.03812662760417103)
Driver: GTiff/GeoTIFF
Size is 1024 x 1024 x 2
Projection is GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]
Origin = (-74.0000000000241, 5.29166666665704)
Pixel Size = (0.03828938802082461, -0.03812662760417103)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  data_sample_size = isotope_counts[i]
  isotope_variance = isotope_variances[i]
  

{'samples_per_location': 5, 'precision_target': 0.95, 'recall_target': 0.95, 'start_max_fraud_radius': 6, 'end_max_fraud_radius': 3000, 'radius_pace': <class 'int'>, 'max_fraud_dist': 3000, 'min_trusted_dist': 5, 'elements_to_eval': ['d18O_cel'], 'mean_rmse': 1.4247102956957474, 'var_rmse': 1.8949035318212868, 'overall_rmse': 1.659806913758517, 'per_radius_eval': [{'radius': 6, 'auc': 0.46617576030083707, 'p_value': 0.05585371082306034, 'precision_target': 0.5555555555555556, 'recall_target': 0.22727272727272727, 'pr_curve': {'precision': [0.5, 0.4883720930232558, 0.5, 0.5121951219512195, 0.5, 0.5128205128205128, 0.5, 0.4864864864864865, 0.5, 0.5142857142857142, 0.5294117647058824, 0.5151515151515151, 0.5, 0.4838709677419355, 0.5, 0.4827586206896552, 0.5, 0.48148148148148145, 0.46153846153846156, 0.44, 0.4166666666666667, 0.43478260869565216, 0.45454545454545453, 0.42857142857142855, 0.4, 0.42105263157894735, 0.4444444444444444, 0.47058823529411764, 0.5, 0.5333333333333333, 0.5, 0.4615

TypeError: Object of type type is not JSON serializable

# Optional Rendering

In [None]:
from matplotlib import rc

# Have matplotlib animations render as interactive JavaScript/HTML players.
rc('animation', html='jshtml')

# Load only band index 0 of the saved isoscape (presumably the predicted
# means — confirm against the raster writer) and animate it.
means_isoscape = raster.load_raster(
    ISOSCAPE_SAVE_LOCATION,
    use_only_band_index=0,
)
raster.animate(means_isoscape, 1, 1)

Driver: GTiff/GeoTIFF
Size is 1024 x 1024 x 2
Projection is GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]
Origin = (-74.0000000000241, 5.29166666665704)
Pixel Size = (0.03828938802082461, -0.03812662760417103)
..

In [None]:
# Load only band index 1 of the saved isoscape (presumably the predicted
# variances — confirm against the raster writer) and animate it.
vars_isoscape = raster.load_raster(
    ISOSCAPE_SAVE_LOCATION,
    use_only_band_index=1,
)
raster.animate(vars_isoscape, 1, 1)

Driver: GTiff/GeoTIFF
Size is 1024 x 1024 x 2
Projection is GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]
Origin = (-74.0000000000241, 5.29166666665704)
Pixel Size = (0.03828938802082461, -0.03812662760417103)
..