# Evaluate model results

In [None]:
import numpy as np
import os, fnmatch
import pandas as pd
import torch
import xarray as xr

%matplotlib inline
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt

import findspark
findspark.init('/p/system/packages/spark/2.3.0')

from lib.spark import ModelData
from pyspark.sql.functions import col

from src.inference import model_inference, get_baseline, evaluate_models, load_prediction
import src.evaluation_geographic as fa
from src.inference import InferenceConfig

# Select the compute device once for the whole notebook: CUDA when torch
# reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Running on {str(device).upper()}.')

# `device` is only a CUDA device when the availability check above passed,
# so its type can stand in for a second availability probe.
if device.type == 'cuda':
    print(torch.cuda.get_device_name())

In [None]:
def save_prediction(name, prediction, target, baseline, out_dir='/path/to/models'):
    """Stack prediction, target and baseline and save them as one ``.npy`` file.

    The three arrays are combined with ``np.stack`` along a new leading axis,
    in the order ``[prediction, target, baseline]``, so they must all have the
    same shape.

    Args:
        name: File name of the output, with or without the ``.npy`` suffix.
        prediction: Array of model predictions.
        target: Array of target values.
        baseline: Array of baseline values.
        out_dir: Directory the file is written to. It must already exist.
    """
    stacked = np.stack([prediction, target, baseline])

    # np.save only appends '.npy' when the name does not already end with it;
    # mirror that here so the reported path is the file actually written.
    path = os.path.join(out_dir, name)
    if not path.endswith('.npy'):
        path += '.npy'

    np.save(path, stacked)
    print('saved at ', path)

In [None]:
model_name = 'UNet-mse'
training_output_path = '/path/to/training_statistics'

# Load the training statistics first: InferenceConfig receives `md` as its
# first argument, so `md` must exist before the config is constructed
# (building the config first raises NameError on a fresh kernel).
md = ModelData()
md.import_data(training_output_path)

config = InferenceConfig(md, [None], '3hourly', 'unet',  'test')

config.data_parallel = False
config.input_format = 'netcdf'
config.num_workers = 0
# Use the device detected at the top of the notebook instead of hard-coding
# 'cuda'; this is still the string 'cuda' whenever a GPU is available.
# NOTE(review): assumes InferenceConfig also accepts 'cpu' -- confirm.
config.device = device.type
config.batch_size = 8

# Query the statistics for `model_name` with the 2021 date bounds, then keep
# the first 50 rows after sorting by the 'date' column.
model_stats = md.get_training_data(model_name, '2021/01/01', '2021/12/31')
model_stats = model_stats.sort_values(['date']).head(50)
display(model_stats)

In [None]:
# Thresholds forwarded unchanged to compute_metrics for every model.
mask_threshold = 0.0
clean_threshold = 0.1
# Metric names requested from compute_metrics; identical for every model,
# so defined once outside the loop.
metrics_list = ['RMSE', 'Bias']
# Maps each model's 'model_name' option to its prediction.
prediction_dict = {}

# Total for the progress line. The original referenced an undefined name
# `models` here, which raised NameError on the first iteration.
n_models = len(model_stats)

for i, uuid in enumerate(model_stats['id'], start=1):

    # 1-based progress counter: "1/N" ... "N/N".
    print(f"{i}/{n_models}")

    options = md.get_training_options(uuid)
    config.model_id = uuid

    print("RUN:")
    print("Features:", options['features'])

    lats, lons, prediction, baseline, target, _ = evaluate_models(config)

    # Reuse the options fetched above rather than querying them a second time.
    model_name = options['model_name']
    prediction_dict[model_name] = prediction
    save_prediction(model_name, prediction, target, baseline)

    # Named `validation` (not `eval`) so the builtin eval() is not shadowed.
    validation = fa.GeographicValidation(lats, lons)

    # The return values are not used here; verbose=True is passed so that
    # compute_metrics reports the results itself.
    _, _, _ = validation.compute_metrics(metrics_list, prediction, baseline, target,
                                         mask_threshold=mask_threshold,
                                         clean_threshold=clean_threshold,
                                         verbose=True)