<a href="https://colab.research.google.com/github/tnc-br/ddf-isoscapes/blob/validation_pipeline_rmse/validation_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Validation Pipeline

In [None]:
# Isoscape raster to validate. It is resolved with get_raster_path and loaded
# twice further down: band index 0 as the means, band index 1 as the variances.
ISOSCAPE_FILENAME = "uc_davis_d18O_cel_kriging.tiff" #@param
# Used in unit tests (generated from Kriging)
TEST_ISOSCAPE_FILENAME = "test_isoscape.tiff" #@param

# CSV of ground-truth samples; resolved with get_sample_db_path and read with
# pandas further down.
TEST_SET_FILENAME = 'uc_davis_2023_08_12_test_random_grouped.csv' #@param
# Columns of values to read ground truths from. Invalid values are 'truth'
# and 'prediction'.
MEAN_TRUTH_NAME = 'd18O_cel_mean' #@param
VAR_TRUTH_NAME = 'd18O_cel_variance' #@param
# Columns of values to write temporary predictions to (for RMSE calculation).
# Invalid values are 'truth' and 'prediction'.
MEAN_PREDICTED_NAME = 'd18O_predicted_mean' #@param
VAR_PREDICTED_NAME = 'd18O_predicted_variance' #@param

# 'truth' and 'prediction' are reserved: calculate_rmse writes scratch columns
# with exactly those names. Note the checked lists also include
# TEST_SET_FILENAME, which is a file name rather than a column name.
assert('truth' not in [TEST_SET_FILENAME, MEAN_TRUTH_NAME, VAR_TRUTH_NAME,
                       MEAN_PREDICTED_NAME, VAR_PREDICTED_NAME])
assert('prediction' not in [TEST_SET_FILENAME, MEAN_TRUTH_NAME, VAR_TRUTH_NAME,
                       MEAN_PREDICTED_NAME, VAR_PREDICTED_NAME])

# Imports (TODO: replace with ddfcommon)

In [None]:
from osgeo import gdal, gdal_array
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
import matplotlib.animation as animation
from matplotlib import rc
from typing import List
from numpy.random import MT19937, RandomState, SeedSequence
import pandas as pd
from tqdm import tqdm
from io import StringIO
import xgboost as xgb
import os
import math
import glob

# Set matplotlib's `animation.html` rc parameter to 'jshtml' so animations
# (e.g. the one returned by animate()) render as JavaScript/HTML in the notebook.
rc('animation', html='jshtml')



In [None]:
# Raster directory. Contains:
# iso_O_cellulose.tif: Isoscape of 18O from Precipitation; <-- MODELING TARGET
#   NOTE(review): this description repeats the Iso_Oxi_Stack.tif one below,
#   while the file name suggests cellulose rather than precipitation -- confirm.
# Iso_Oxi_Stack.tif: Isoscape of 18O from Precipitation; <-- Model input
# R.rh_Stack.tif: Atmospheric Relative humidity <-- Model input
# R.vpd_Stack.tif: Vapor Pressure Deficit - VPD <-- Model input
# Temperature_Stack.tif: Atmospheric Temperature <-- Model input
#
# RASTER_BASE, MODEL_BASE and SAMPLE_DATA_BASE are joined by plain string
# concatenation (GDRIVE_BASE + base + filename) in the get_*_path helpers, so
# the leading and trailing "/" are required.
RASTER_BASE = "/MyDrive/amazon_rainforest_files/amazon_rasters/" #@param
MODEL_BASE = "/MyDrive/amazon_rainforest_files/amazon_isoscape_models/" #@param
SAMPLE_DATA_BASE = "/MyDrive/amazon_rainforest_files/amazon_sample_data/" #@param
# TEST_DATA_BASE and BIOME_DATA_PATH are not referenced by the code visible in
# this notebook.
TEST_DATA_BASE = "/MyDrive/amazon_rainforest_files/amazon_test_data/" #@param
BIOME_DATA_PATH = "/MyDrive/amazon_rainforest_files/christian_files/lm_bioma_250.shp" #@param
# Mount point for Google Drive and prefix of every path built by the helpers.
# A falsy value (e.g. "") skips the drive.mount() call further down.
GDRIVE_BASE = "/content/drive" #@param

# How often should XGB log training metadata? 0 is the default, which indicates never.
XGB_VERBOSITY_LEVEL = 0 #@param

In [None]:
@dataclass
class AmazonGeoTiff:
  """Represents a geotiff from our dataset.

  Instances are created by load_raster(); the per-field comments describe what
  load_raster() stores in each field.
  """
  # Dataset handle returned by gdal.Open for the source file.
  gdal_dataset: gdal.Dataset
  # Pixel values, shape (RasterYSize, RasterXSize, 12).
  image_value_array: np.ndarray # ndarray of floats
  # Values read from each band's GDAL mask band, same shape as the image array.
  image_mask_array: np.ndarray # ndarray of uint8
  # image_value_array with every position whose mask value is 0 masked out.
  masked_image: np.ma.masked_array
  # Mean of masked_image over its last axis (the 12 band slots).
  yearly_masked_image: np.ma.masked_array

@dataclass
class Bounds:
  """Extent of a raster plus its pixel size and grid dimensions.

  Instances are frequently built positionally, so the field order below is
  part of the interface.
  """
  # Horizontal edges of the extent.
  minx: float
  maxx: float
  # Vertical edges of the extent.
  miny: float
  maxy: float
  # Per-pixel step along each axis (get_extent copies geotransform[1] and [5]).
  pixel_size_x: float
  pixel_size_y: float
  # Raster width and height in pixels.
  raster_size_x: float
  raster_size_y: float

  def to_matplotlib(self) -> List[float]:
    """Returns [minx, maxx, miny, maxy], the list passed as `extent=` to imshow."""
    horizontal_edges = [self.minx, self.maxx]
    vertical_edges = [self.miny, self.maxy]
    return horizontal_edges + vertical_edges

@dataclass
class PartitionedDataset:
  """Groups three DataFrames: the train, test and validation partitions.

  Not referenced by the other code visible in this notebook.
  """
  train: pd.DataFrame
  test: pd.DataFrame
  validation: pd.DataFrame

In [None]:
def get_raster_path(filename: str) -> str:
  """Returns GDRIVE_BASE + RASTER_BASE + filename as one path string."""
  return "{}{}{}".format(GDRIVE_BASE, RASTER_BASE, filename)

def get_model_path(filename: str) -> str:
  """Returns GDRIVE_BASE + MODEL_BASE + filename as one path string."""
  return "{}{}{}".format(GDRIVE_BASE, MODEL_BASE, filename)

def get_sample_db_path(filename: str) -> str:
  """Returns GDRIVE_BASE + SAMPLE_DATA_BASE + filename as one path string."""
  return "{}{}{}".format(GDRIVE_BASE, SAMPLE_DATA_BASE, filename)

In [None]:
def print_raster_info(raster):
  """Prints a short summary of an opened GDAL dataset.

  Output: driver short/long name, raster size (x, y, band count), projection,
  origin and pixel size (when a geotransform is available) and, per band, the
  overview count and color-table size when present.

  The previous per-band min/max lookup was removed: its result was never
  printed or returned, it shadowed the builtins `min`/`max`, treated a
  legitimate minimum of 0.0 as "missing", and could trigger a full scan of
  every band through ComputeRasterMinMax.

  Args:
    raster: Dataset object as returned by gdal.Open.
  """
  driver = raster.GetDriver()
  print(f"Driver: {driver.ShortName}/{driver.LongName}")
  print(f"Size is {raster.RasterXSize} x {raster.RasterYSize} x {raster.RasterCount}")
  print(f"Projection is {raster.GetProjection()}")

  geotransform = raster.GetGeoTransform()
  if geotransform:
    # Elements 0 and 3 are the origin, 1 and 5 the pixel size along x and y.
    print(f"Origin = ({geotransform[0]}, {geotransform[3]})")
    print(f"Pixel Size = ({geotransform[1]}, {geotransform[5]})")

  # GDAL band numbers are 1-based.
  for band_number in range(1, raster.RasterCount + 1):
    band = raster.GetRasterBand(band_number)

    overview_count = band.GetOverviewCount()
    if overview_count > 0:
      print(f"Band has {overview_count} overviews")

    color_table = band.GetRasterColorTable()
    if color_table:
      print(f"Band has a color table with {color_table.GetCount()} entries")

def load_raster(path: str, use_only_band_index: int = -1) -> AmazonGeoTiff:
  """
  Opens the raster at `path` read-only and loads it into an AmazonGeoTiff.

  The value and mask arrays always have 12 slots on their last axis. When a
  single band is used (requested explicitly, or because the file has exactly
  one band), that band is copied into all 12 slots; otherwise slot i holds
  band i+1 of the file.

  TODO: Refactor (is_single_band, etc., should be a better design)
  --> Find a way to simplify this logic. Maybe it needs to be more abstract.

  Args:
    path: Location of the raster file.
    use_only_band_index: 0-based index of the only band to read, or -1 to
      read all bands (the file must then have 12 bands or 1 band).

  Returns:
    AmazonGeoTiff holding the dataset, the value array, the mask array, the
    value array masked where the mask is 0, and its mean over the 12 slots.

  Raises:
    RuntimeError: if GDAL returns no dataset for `path`.
    ValueError: if no band index is given and the band count is not 12 or 1.
    IndexError: if the requested band index is not below the band count.
  """
  dataset = gdal.Open(path, gdal.GA_ReadOnly)
  if dataset is None:
    # Without GDAL exceptions enabled, Open() signals failure by returning
    # None; fail here with the path instead of an AttributeError later on.
    raise RuntimeError(f"Unable to open raster at {path}")
  print(dataset)
  print_raster_info(dataset)

  first_band = dataset.GetRasterBand(1)
  image_datatype = first_band.DataType
  mask_datatype = first_band.GetMaskBand().DataType
  image = np.zeros((dataset.RasterYSize, dataset.RasterXSize, 12),
                  dtype=gdal_array.GDALTypeCodeToNumericTypeCode(image_datatype))
  # The mask gets the mask band's own dtype (it used to be allocated with the
  # image dtype, leaving mask_datatype unused).
  mask = np.zeros((dataset.RasterYSize, dataset.RasterXSize, 12),
                  dtype=gdal_array.GDALTypeCodeToNumericTypeCode(mask_datatype))

  if use_only_band_index == -1:
    if dataset.RasterCount != 12 and dataset.RasterCount != 1:
      raise ValueError(f"Expected 12 raster bands (one for each month) or one annual average, but found {dataset.RasterCount}")
    if dataset.RasterCount == 1:
      use_only_band_index = 0

  is_single_band = use_only_band_index != -1

  if is_single_band and use_only_band_index >= dataset.RasterCount:
    raise IndexError(f"Specified raster band index {use_only_band_index}"
    f" but there are only {dataset.RasterCount} rasters")

  for band_index in range(12):
    # GDAL band numbers are 1-based.
    band = dataset.GetRasterBand(use_only_band_index+1 if is_single_band else band_index+1)
    image[:, :, band_index] = band.ReadAsArray()
    mask[:, :, band_index] = band.GetMaskBand().ReadAsArray()
  masked_image = np.ma.masked_where(mask == 0, image)
  yearly_masked_image = masked_image.mean(axis=2)

  return AmazonGeoTiff(dataset, image, mask, masked_image, yearly_masked_image)

def get_extent(dataset):
  """Returns the Bounds (extent, pixel size, raster size) of a GDAL dataset.

  The y origin (geotransform[3]) and the opposite y edge are ordered with
  min/max. For a raster whose y pixel size (geotransform[5]) is negative the
  origin is the larger y value; assigning it to `miny` unconditionally, as
  before, produced miny > maxy, which makes coords_to_indices reject every
  coordinate. Rasters with a positive y pixel size get the same result as
  before.

  Args:
    dataset: Dataset object as returned by gdal.Open.

  Returns:
    Bounds with miny <= maxy whenever the y pixel size is non-zero; the pixel
    sizes are passed through with their original sign.
  """
  geoTransform = dataset.GetGeoTransform()
  pixel_size_x = geoTransform[1]
  pixel_size_y = geoTransform[5]

  minx = geoTransform[0]
  maxx = minx + pixel_size_x * dataset.RasterXSize

  y_origin = geoTransform[3]
  y_opposite = y_origin + pixel_size_y * dataset.RasterYSize
  miny = min(y_origin, y_opposite)
  maxy = max(y_origin, y_opposite)

  return Bounds(minx, maxx, miny, maxy, pixel_size_x, pixel_size_y, dataset.RasterXSize, dataset.RasterYSize)

def plot_band(geotiff: AmazonGeoTiff, month_index, figsize=None):
  """Draws one band slot of `geotiff` with a colorbar on the current figure.

  Args:
    geotiff: Raster whose masked_image is plotted.
    month_index: Index into the last axis of masked_image.
    figsize: Optional figure size; when truthy a new figure of that size is
      created first.
  """
  if figsize:
    plt.figure(figsize=figsize)
  band_image = geotiff.masked_image[:, :, month_index]
  plot_extent = get_extent(geotiff.gdal_dataset).to_matplotlib()
  rendered = plt.imshow(band_image, extent=plot_extent, interpolation='none')
  plt.colorbar(rendered)

def animate(geotiff: AmazonGeoTiff, nSeconds, fps):
  """Builds an animation cycling through the 12 band slots of `geotiff`.

  The animation has nSeconds * fps frames. Frame i shows slot i modulo 12, so
  animations longer than 12 frames loop over the months instead of raising
  IndexError as the direct `months[i]` lookup used to.

  Args:
    geotiff: Raster whose masked_image[:, :, 0..11] slices are shown.
    nSeconds: Duration of the animation in seconds.
    fps: Frames per second; also sets the frame interval (1000 / fps ms).

  Returns:
    The matplotlib FuncAnimation. The figure is closed before returning.
  """
  fig = plt.figure(figsize=(8, 8))

  months = [geotiff.masked_image[:, :, m] for m in range(12)]
  labels = [f"Month: {m+1}" for m in range(12)]

  extent = get_extent(geotiff.gdal_dataset).to_matplotlib()
  ax = fig.add_subplot()
  im = ax.imshow(months[0], interpolation='none', aspect='auto', extent=extent)
  txt = fig.text(0.3, 0, "", fontsize=24)
  fig.colorbar(im)

  def animate_func(i):
    # Progress indicator: one dot per second of animation rendered.
    if i % fps == 0:
      print('.', end='')

    # Wrap around so any frame count maps onto the 12 available slots.
    month = i % len(months)
    im.set_array(months[month])
    txt.set_text(labels[month])
    return [im, txt]

  anim = animation.FuncAnimation(
      fig,
      animate_func,
      # An integer frame count is required; int() is a no-op for int inputs.
      frames=int(nSeconds * fps),
      interval=1000 / fps,  # in ms
  )
  plt.close()

  return anim

def save_numpy_to_geotiff(bounds: Bounds, prediction: np.ma.MaskedArray, path: str):
  """Writes `prediction` (values and mask) to a new Float64 GeoTIFF at `path`.

  The geotransform is built from `bounds` (origin at minx / maxy, pixel sizes
  as given) and the projection is hard-coded to WGS 84 (EPSG:4326). One band
  is written per slice of the last axis of `prediction`. Before writing, the
  array is transposed to (y, x, band) and flipped along its first axis.

  Args:
    bounds: Extent, pixel size and raster size of the output.
    prediction: Array of shape (raster_size_x, raster_size_y, bands). Masked
      positions are written as the band's GetNoDataValue() and flagged in the
      band's mask band.
    path: Output file path.

  Raises:
    ValueError: if `prediction` is not 3-D or does not match the raster size.
    RuntimeError: if the GTiff driver lacks Create/CreateCopy, if the dataset
      cannot be created, or if a mask band cannot be created.
  """
  # Validate first: the old order read prediction.shape[2] (IndexError for a
  # non-3-D array) and created the output file before checking the shape.
  if len(prediction.shape) != 3 or prediction.shape[0] != bounds.raster_size_x or prediction.shape[1] != bounds.raster_size_y:
    raise ValueError("Shape of prediction does not match base geotiff")

  driver = gdal.GetDriverByName("GTiff")
  metadata = driver.GetMetadata()
  if metadata.get(gdal.DCAP_CREATE) != "YES":
      raise RuntimeError("GTiff driver does not support required method Create().")
  if metadata.get(gdal.DCAP_CREATECOPY) != "YES":
      raise RuntimeError("GTiff driver does not support required method CreateCopy().")

  dataset = driver.Create(path, bounds.raster_size_x, bounds.raster_size_y, prediction.shape[2], eType=gdal.GDT_Float64)
  if dataset is None:
    raise RuntimeError(f"Failed to create GeoTIFF at {path}")
  dataset.SetGeoTransform([bounds.minx, bounds.pixel_size_x, 0, bounds.maxy, 0, bounds.pixel_size_y])
  dataset.SetProjection('GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433],AUTHORITY["EPSG","4326"]]')

  prediction_transformed = np.flip(np.transpose(prediction, axes=[1,0,2]), axis=0)
  for band_index in range(dataset.RasterCount):
    band = dataset.GetRasterBand(band_index+1)
    if band.CreateMaskBand(0) == gdal.CE_Failure:
      raise RuntimeError("Failed to create mask band")
    mask_band = band.GetMaskBand()

    band_values = prediction_transformed[:, :, band_index]
    # getmaskarray always yields a full boolean array, even when the input
    # carries no mask (`.mask` would then be a scalar that cannot be written).
    band_mask = np.ma.getmaskarray(band_values)
    band.WriteArray(np.choose(band_mask, (np.ma.getdata(band_values), np.array(band.GetNoDataValue()),)))
    # The mask band stores validity, i.e. the inverse of the numpy mask.
    mask_band.WriteArray(np.logical_not(band_mask))

def coords_to_indices(bounds: Bounds, x: float, y: float):
  """Converts a coordinate inside `bounds` into array indices.

  Args:
    bounds: Extent and raster size; only the magnitude of the pixel sizes is
      used.
    x: Horizontal coordinate (callers in this file pass longitude).
    y: Vertical coordinate (callers in this file pass latitude).

  Returns:
    (x_idx, y_idx). Despite the names, x_idx is derived from `y` and counts
    rows down from maxy, and y_idx is derived from `x` and counts columns from
    minx; callers index arrays as [x_idx, y_idx]. Both are clamped to the
    valid range [0, raster_size - 1].

  Raises:
    ValueError: if (x, y) lies outside the bounds.
  """
  if x < bounds.minx or x > bounds.maxx or y < bounds.miny or y > bounds.maxy:
    raise ValueError("Coordinates out of bounds")

  rows_above_miny = int(math.ceil((y - bounds.miny) / abs(bounds.pixel_size_y)))
  x_idx = bounds.raster_size_y - rows_above_miny
  y_idx = int((x - bounds.minx) / abs(bounds.pixel_size_x))

  # Points exactly on the miny / maxx edge used to map to an index equal to
  # the raster size (one past the end), and float rounding near maxy could
  # give -1, which silently wraps to the last row. Clamp to the nearest cell.
  x_idx = min(max(x_idx, 0), bounds.raster_size_y - 1)
  y_idx = min(max(y_idx, 0), bounds.raster_size_x - 1)

  return x_idx, y_idx

def test_coords_to_indices():
  """Checks coords_to_indices against hand-computed (x_idx, y_idx) pairs."""
  # Each case: (bounds, (x, y) coordinate, expected (x_idx, y_idx)).
  cases = [
      (Bounds(50, 100, 50, 100, 1, 1, 50, 50), (55, 55), (45, 5)),
      (Bounds(-100, -50, -100, -50, 1, 1, 50, 50), (-55, -55), (5, 45)),
      (Bounds(-10, 50, -10, 50, 1, 1, 60, 60), (-1, 13), (37, 9)),
      # Real-world extent; the expected first index was 132 in an earlier version.
      (Bounds(minx=-73.97513931345594, maxx=-34.808472803053895,
              miny=-33.73347244751509, maxy=5.266527396029211,
              pixel_size_x=0.04166666650042771,
              pixel_size_y=-0.041666666499513144,
              raster_size_x=937, raster_size_y=941),
       (-67.14342073173958, -7.273271869467912e-05),
       (131, 163)),
  ]
  for bounds, (coord_x, coord_y), (expected_first, expected_second) in cases:
    first_idx, second_idx = coords_to_indices(bounds, coord_x, coord_y)
    assert first_idx == expected_first
    assert second_idx == expected_second

test_coords_to_indices()

def get_data_at_coords(dataset: AmazonGeoTiff, x: float, y: float, month: int) -> float:
  """Looks up the raster value at a coordinate.

  Args:
    dataset: Raster to read from.
    x: Longitude.
    y: Latitude.
    month: Index into the last axis of masked_image, or -1 to read from
      yearly_masked_image instead.

  Returns:
    The value stored at the indices computed by coords_to_indices.

  Raises:
    ValueError: if the value at that position is masked (and whatever
      coords_to_indices raises for coordinates outside the extent).
  """
  extent = get_extent(dataset.gdal_dataset)
  first_idx, second_idx = coords_to_indices(extent, x, y)
  wants_yearly = month == -1
  value = (dataset.yearly_masked_image[first_idx, second_idx] if wants_yearly
           else dataset.masked_image[first_idx, second_idx, month])
  if np.ma.is_masked(value):
    raise ValueError("Coordinates are masked")
  return value

In [None]:
# Access data stored on Google Drive
# Mount Google Drive at GDRIVE_BASE so the paths built by the get_*_path
# helpers resolve. Skipped entirely when GDRIVE_BASE is falsy (e.g. "").
# The import sits inside the branch so google.colab is only needed when
# mounting -- presumably to allow runs outside Colab; confirm.
if GDRIVE_BASE:
    from google.colab import drive
    drive.mount(GDRIVE_BASE)

# Isoscape: Calculate RMSE

In [None]:
from sklearn.metrics import mean_squared_error

In [None]:
def calculate_rmse(df, means_isoscape, vars_isoscape, mean_true_name, var_true_name, mean_pred_name, var_pred_name):
  '''
  Calculates the mean, variance and overall (mean and variance) RMSE of df.

  Predictions are looked up in the isoscapes at each row's ('long', 'lat')
  using the yearly image (month=-1) and compared with the ground-truth
  columns. The RMSE is computed with numpy rather than
  sklearn's mean_squared_error(..., squared=False), whose `squared` argument
  is deprecated and removed in recent scikit-learn releases. The overall value
  is the average of the two per-column RMSEs, which is what the two-column
  sklearn call returned with its default uniform averaging.

  Side effects: df is modified in place. Columns mean_pred_name,
  var_pred_name, 'prediction' and 'truth' are (over)written, and df.columns is
  printed.

  Args:
    df: DataFrame with 'long', 'lat', mean_true_name and var_true_name columns.
    means_isoscape: AmazonGeoTiff providing predicted means.
    vars_isoscape: AmazonGeoTiff providing predicted variances.
    mean_true_name, var_true_name: ground-truth column names.
    mean_pred_name, var_pred_name: column names to write predictions to.
      All four names must be distinct and must not be 'truth' or 'prediction'.

  Returns:
    Tuple (mean RMSE, variance RMSE, overall RMSE).

  Raises:
    ValueError: if the column names collide, or (from get_data_at_coords) if
      a sample lies outside the isoscape or on a masked pixel.
  '''
  # Make sure names do not collide. Raised explicitly instead of asserted so
  # the check is not stripped when Python runs with -O.
  column_names = [mean_true_name, var_true_name, mean_pred_name, var_pred_name, 'truth', 'prediction']
  if len(column_names) != len(set(column_names)):
    raise ValueError(f"Column names must be unique and must not be 'truth' or 'prediction': {column_names}")

  df[mean_pred_name] = df.apply(lambda row: get_data_at_coords(means_isoscape, row['long'], row['lat'], -1), axis=1)
  df[var_pred_name] = df.apply(lambda row: get_data_at_coords(vars_isoscape, row['long'], row['lat'], -1), axis=1)

  print(df.columns)

  # Kept for backward compatibility: these combined columns were written by
  # earlier versions and may be inspected by callers.
  df['prediction'] = df.apply(lambda row: [row[mean_pred_name], row[var_pred_name]], axis=1)
  df['truth'] = df.apply(lambda row: [row[mean_true_name], row[var_true_name]], axis=1)

  def _rmse(true_column, pred_column):
    # Root of the mean squared difference between two columns of df.
    errors = df[true_column].to_numpy(dtype=float) - df[pred_column].to_numpy(dtype=float)
    return np.sqrt(np.mean(np.square(errors)))

  mean_rmse = _rmse(mean_true_name, mean_pred_name)
  var_rmse = _rmse(var_true_name, var_pred_name)
  overall_rmse = (mean_rmse + var_rmse) / 2

  return (mean_rmse, var_rmse, overall_rmse)

In [None]:
import pytest

def test_calculate_rmse():
  """Runs calculate_rmse on two fixed samples against the test isoscape.

  Band index 0 of the test raster supplies the means and band index 1 the
  variances; the three returned RMSE values are compared with stored numbers.
  """
  isoscape_path = get_raster_path(TEST_ISOSCAPE_FILENAME)
  test_means_isoscape = load_raster(isoscape_path, use_only_band_index=0)
  test_vars_isoscape = load_raster(isoscape_path, use_only_band_index=1)

  mean_true_name = 'd18O_cel_mean'
  var_true_name = 'd18O_cel_var'
  mean_pred_name = 'd18O_cel_mean_pred'
  var_pred_name = 'd18O_cel_var_pred'

  samples = {
      'lat': [-3, -4],
      'long': [-55, -54],
      'd18O_cel_mean': [0, 5],
      'd18O_cel_var': [1, 0.5]
  }
  df = pd.DataFrame(samples)

  rmse_values = calculate_rmse(
      df, test_means_isoscape, test_vars_isoscape,
      mean_true_name, var_true_name, mean_pred_name, var_pred_name)
  mean_rmse, var_rmse, overall_rmse = rmse_values

  assert mean_rmse == pytest.approx(22.524833655412866)
  assert var_rmse == pytest.approx(11.00233730921582)
  assert overall_rmse == pytest.approx(16.763585482314344)

test_calculate_rmse()

In [None]:
# Load the isoscape under validation. The same file is opened twice:
# band index 0 is used as the predicted means, band index 1 as the predicted
# variances (load_raster copies the chosen band into all 12 slots).
print(get_raster_path(ISOSCAPE_FILENAME))
means_isoscape = load_raster(get_raster_path(ISOSCAPE_FILENAME), use_only_band_index=0)
vars_isoscape = load_raster(get_raster_path(ISOSCAPE_FILENAME), use_only_band_index=1)

In [None]:
# Read the ground-truth samples, using the first CSV column as the index.
# Note the file is resolved under SAMPLE_DATA_BASE (get_sample_db_path).
eval_dataset = pd.read_csv(get_sample_db_path(TEST_SET_FILENAME), index_col=0)
# Last expression of the cell: shows the first rows as the cell output.
eval_dataset.head()

In [None]:
# Compare the isoscape with the ground truth. calculate_rmse adds the
# MEAN_PREDICTED_NAME / VAR_PREDICTED_NAME columns (plus 'prediction' and
# 'truth') to eval_dataset in place.
mean_rmse, var_rmse, overall_rmse = calculate_rmse(eval_dataset, means_isoscape, vars_isoscape, MEAN_TRUTH_NAME, VAR_TRUTH_NAME, MEAN_PREDICTED_NAME, VAR_PREDICTED_NAME)

In [None]:
# Report the three values returned by calculate_rmse in the previous cell.
print("RMSE of Means:", mean_rmse)
print("RMSE of Vars:", var_rmse)
print("Overall RMSE:", overall_rmse)

# TODO: Fraud Detection Hypothesis Test