In [3]:
import os, sys, time
from datetime import datetime
import pandas as pd
import numpy as np
import geojson, json
import xarray as xr
import rioxarray

from sklearn.cross_decomposition import CCA
from sklearn.decomposition import PCA

import importlib
from sklearn.preprocessing import StandardScaler

from sklearn.linear_model import LinearRegression, Ridge, Lasso
#from sklearn.linear_model import RidgeCV, LassoCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.ensemble import RandomForestRegressor

from sklearn.model_selection import cross_val_score, RepeatedKFold, LeaveOneOut, LeavePOut, KFold, cross_val_predict

from sklearn.metrics import r2_score, mean_squared_error, roc_auc_score, mean_absolute_percentage_error, mean_squared_error, explained_variance_score

from sklearn.base import BaseEstimator, RegressorMixin

from dev_functions.dev_functions import *

from rasterstats import zonal_stats

from matplotlib import pyplot as plt
import matplotlib.colors as colors

import cartopy.crs as ccrs

import gl

# Three-month season codes built from consecutive month initials, wrapping around the year.
# Index i is the season whose first month is month i+1, so seasons[0] == "JFM" and seasons[11] == "DJF".
seasons = ["".join("JFMAMJJASOND"[(first + k) % 12] for k in range(3)) for first in range(12)]


# Message category -> display colour. showMessage accepts these categories as `_type`.
msgColors = dict(
    ERROR="red",
    INFO="blue",
    RUNTIME="grey",
    NONCRITICAL="red",
    SUCCESS="green",
)

def showMessage(_message, _type="RUNTIME"):
    """Report a status message to the user.

    In this notebook the message is simply written to stdout, one message per line.

    Parameters
    ----------
    _message : object
        Text (or any printable object) to report.
    _type : str, optional
        Message category, one of the keys of ``msgColors``. It is accepted for interface
        compatibility but currently ignored here — presumably it is used for colouring in a
        GUI log window (NOTE(review): confirm against the GUI version).
    """
    sys.stdout.write(f"{_message}\n")

    
#cross-validation approaches to choose from, keyed by gl.config['crossVal'].
#KFold uses scikit-learn's default shuffle=False, i.e. contiguous blocks of years per fold.
cvs={"LOO":LeaveOneOut(),
     "KF":KFold(n_splits=5)
}

#keyword arguments for each regression model, keyed by gl.config['model'] and passed on
#to the PCR/CCA regressor wrapper. Can be read from json - potentially editable by user.
#random_state is fixed for the models whose fitting is stochastic, so that repeated runs
#produce identical hindcasts and forecasts.
#NOTE(review): "MPL" presumably stands for MLP (MLPRegressor); the key is kept as-is
#because the model may be selected by this exact name elsewhere.
regressor_configs = {
    "Linear regression": {},
    "Lasso regression": {'alpha': 0.01},
    "Ridge regression": {'alpha': 1.0},
    "Random Forest": {'n_estimators': 100, 'max_depth': 5, 'random_state': 0},
    "MPL": {'hidden_layer_sizes': (50, 25), 'max_iter': 1000, 'random_state': 0},
    'Decision Trees': {'max_depth': 2, 'random_state': 0}
}



class StopEarly(Exception):
    """Signal that the processing pipeline should halt without a traceback.

    Raised by the main processing cell when a step fails (for example when a reader returns
    None). That cell catches it and reports "Execution stopped".
    """

# Run configuration, stored on the project module `gl` so that the helper functions can read it
# (presumably the ones star-imported from dev_functions — confirm).
gl.config={}

# climatological reference period (years, inclusive); used to compute forecast anomalies
gl.config["climEndYr"]=2015
gl.config["climStartYr"]=1994

# predictors: list of [file path, variable name] pairs
#gl.config['predictorFileList'] = [["./data/SEAS51.SST.nc","sst"]]
gl.config['predictorFileList'] = [["./data/SST_Jun_1960-2025.nc","sst"]]
gl.config['predictorMonth'] = "Jun"

# predictor domain bounds in degrees (currently global)
gl.config['basinMinLat'] = -90
gl.config['basinMaxLat'] = 90
gl.config['basinMinLon'] = -180
gl.config['basinMaxLon'] = 180   # was -180, identical to basinMinLon, i.e. a zero-width longitude range

# predictand source - only the uncommented assignments are effective; alternatives kept for reference
#gl.config["predictandFileName"]="./data/PRCPTOT_mon_CHIRPS-v2.0-p05-merged_cft_stations_BWA.csv"
#gl.config["predictandFileName"]="./data/predictand_test_format.csv"
#gl.config["predictandFileFormat"]="csv"
gl.config["predictandFileName"]="./data/pr_mon_chirps-v2.0_198101-202308.nc"
gl.config["predictandFileFormat"]="netcdf"
gl.config["predictandVar"]="PRCPTOT"

# forecast target; with fcstBaseTime "seas", fcstTargetMonth is the first month of the target season
gl.config['fcstBaseTime']="seas"
gl.config['temporalAggregation']="mean"
gl.config['fcstTargetMonth']="Dec"
gl.config['fcstTargetYear']=2025
gl.config['predictandCategory'] = 'pr'
gl.config['predictandName'] = 'rainfall'
gl.config['predictandMissingValue']=-999

# statistical method ("PCR" or "CCA") and the explained-variance setting
gl.config['method']="PCR"
gl.config['expVariance']=0.8

# regression model (a key of regressor_configs) and cross-validation scheme (a key of cvs)
gl.config['model']="Linear regression"
gl.config['crossVal']="KF"

# zones used for spatial aggregation and for map overlays
gl.config["zonesMap"]="data/Botswana.geojson"
gl.config["spatialAggregate"]=True

gl.config["zonesID"]="ID"
gl.maxLeadTime=6
gl.maxPercentPCsRetain=15
gl.config['rootDir']="../forecast"

#set target type: aggregated zones take precedence, then csv point data, otherwise gridded data
if gl.config["spatialAggregate"]:
    gl.config["targetType"]="zones"
elif gl.config["predictandFileFormat"]=="csv":
    gl.config["targetType"]="points"
else:
    gl.config["targetType"]="grid"
gl.config["targetType"]

'zones'

In [4]:
#this is the main processing stream.
#helper steps signal failure by returning None, which is turned into StopEarly here;
#StopEarly ends the run with "Execution stopped".
try:
    #validating user-selectable options up front, before any expensive data reading.
    #previously an unknown method left `regressor` undefined (or silently reused a stale one
    #from an earlier run of this cell), and unknown model/crossVal keys raised an uncaught KeyError.
    if gl.config['method'] not in ("PCR", "CCA"):
        showMessage("Unknown method '{}'. Expected 'PCR' or 'CCA'.".format(gl.config['method']), "ERROR")
        raise StopEarly
    if gl.config['model'] not in regressor_configs:
        showMessage("Unknown model '{}'. Expected one of: {}".format(
            gl.config['model'], ", ".join(regressor_configs)), "ERROR")
        raise StopEarly
    if gl.config['crossVal'] not in cvs:
        showMessage("Unknown cross-validation scheme '{}'. Expected one of: {}".format(
            gl.config['crossVal'], ", ".join(cvs)), "ERROR")
        raise StopEarly

    leadTime=getLeadTime()
    if leadTime is None:
        raise StopEarly

    #reading predictors data
    predictors=readPredictors()
    if predictors is None:
        raise StopEarly

    #reading predictand data - this will calculate seasonal from monthly if needed.
    predictand, geoData=readPredictand()
    if predictand is None:
        raise StopEarly

    if gl.config["zonesMap"]!="":
        #NOTE(review): `gpd` (geopandas) is not imported in the import cell; presumably it
        #comes from `from dev_functions.dev_functions import *` - confirm.
        zonesVector=gpd.read_file(gl.config["zonesMap"])
    else:
        zonesVector=None

    if gl.config["targetType"]=="zones":
        showMessage("Aggregating data to zones read from {} ...".format(gl.config["zonesMap"]))
        predictand,geoData=aggregatePredictand(predictand, geoData, zonesVector)

    #defining target date for forecast. If seasonal - then this is the first month of the season.
    fcstTgtDate=pd.to_datetime("01 {} {}".format(gl.config['fcstTargetMonth'], gl.config['fcstTargetYear']))

    gl.config["fcstTgtCode"]=seasons[fcstTgtDate.month-1]
    #will have to implement iteration through predictors?? for the time being - just a single predictor

    #finding overlap of predictand and predictor; stop before doing any further work if it failed
    showMessage("Aligning predictor and predictand data...")
    predictandHcst,predictorHcst=getHcstData(predictand,predictors[0])
    if predictandHcst is None:
        raise StopEarly
    predictorFcst=getFcstData(predictors[0])
    if predictorFcst is None:
        raise StopEarly

    showMessage("Setting up directories to write to...")
    #gl.predictorDate is not set in this cell - presumably by readPredictors()
    if gl.config['fcstBaseTime']=="seas":
        forecastID="{}-{}".format(gl.predictorDate.strftime("%Y%m"), gl.config["fcstTgtCode"])
    else:
        forecastID="{}-{}".format(gl.predictorDate.strftime("%Y%m"), fcstTgtDate.strftime("%b"))
    forecastDir="{}/{}/{}".format(gl.config['rootDir'], forecastID, gl.config['targetType'])

    mapsDir="{}/maps/".format(forecastDir)
    timeseriesDir="{}/timeseries/".format(forecastDir)
    outputDir="{}/output/".format(forecastDir)
    diagsDir="{}/diagnostics/".format(forecastDir)

    for adir in [mapsDir,outputDir, diagsDir,timeseriesDir]:
        if not os.path.exists(adir):
            print("\toutput directory {} does not exist. creating...".format(adir))
            os.makedirs(adir, exist_ok=True)
            print("\tdone")
    showMessage("done")

    #calculating observed terciles
    #is there a need to do a strict control of overlap???
    result=getObsTerciles(predictand, predictandHcst)
    if result is None:
        raise StopEarly
    obsTercile,tercThresh=result

    #setting up cross-validation (key validated above)
    cv=cvs[gl.config['crossVal']]

    #arguments for regressor (key validated above)
    kwargs=regressor_configs[gl.config['model']]

    #regression model; method validated above, so exactly one branch applies
    if gl.config['method']=="PCR":
        regressor = PCRegressor(regressor_name=gl.config['model'], **kwargs)
    elif gl.config['method']=="CCA":
        regressor = CCAregressor(regressor_name=gl.config['model'], **kwargs)

    #cross-validated hindcast
    showMessage("Calculating cross-validated hindcast...")
    cvHcst = cross_val_predict(regressor,predictorHcst,  predictandHcst, cv=cv)
    cvHcst=pd.DataFrame(cvHcst, index=predictandHcst.index, columns=predictandHcst.columns)

    #actual prediction, from a model fitted on the full hindcast period
    showMessage("Calculating deterministic forecast...")
    regressor.fit(predictorHcst,  predictandHcst)
    detFcst=regressor.predict(predictorFcst)
    detFcst=pd.DataFrame(detFcst, index=[fcstTgtDate], columns=predictandHcst.columns)

    #calculate forecast anomalies relative to the climatological reference period (inclusive years)
    refData=predictand.loc[str(gl.config["climStartYr"]):str(gl.config["climEndYr"])]
    detFcst=getFcstAnomalies(detFcst,refData)

    #deriving probabilistic prediction
    showMessage("Calculating probabilistic hindcast and forecast using error variance...")
    result=probabilisticForecast(cvHcst, predictandHcst,detFcst["forecast"],tercThresh)
    if result is None:
        raise StopEarly
    probFcst,probHcst=result
    showMessage("Hindcast and forecast calculated.")

    #calculating skill
    showMessage("Calculating skill scores...")
    scores=getSkill(probHcst,cvHcst,predictandHcst,obsTercile)
    if scores is None:
        raise StopEarly

    #saving data - TODO

    showMessage("Plotting forecast maps...")
    #plotting forecast. NOTE: the lower-case detfcst/probfcst are reshaped copies for plotting;
    #detFcst/probFcst (camel case) keep the original layout and are used for the time series below.
    if gl.config["targetType"]=="grid":
        detfcst=detFcst.stack(level=["lat","lon"],future_stack=True).droplevel(0).T
        probfcst=probFcst.stack(level=["lat","lon"],future_stack=True).droplevel(0).T
    else:
        detfcst=detFcst.stack(future_stack=True).droplevel(0).T
        probfcst=probFcst.stack(future_stack=True).droplevel(0).T

    plotMaps(detfcst, geoData, mapsDir, forecastID, zonesVector)
    plotMaps(probfcst, geoData, mapsDir, forecastID, zonesVector)

    showMessage("Plotting skill maps...")
    #plotting skill scores
    plotMaps(scores, geoData, mapsDir, forecastID, zonesVector)

    showMessage("Plotting time series...")
    plotTimeSeries(cvHcst,predictandHcst, detFcst, tercThresh, timeseriesDir, forecastID)

    showMessage("All done!")

except (TypeError, ValueError) as e:
    print(f"An error occurred: {e}")
    #the first traceback entry is this cell's frame (the call that failed); walk to the
    #last entry to also report where inside the helper functions the error was raised.
    tb = e.__traceback__
    if tb is not None:
        print(f"Error occurred on line: {tb.tb_lineno}")
        while tb.tb_next is not None:
            tb = tb.tb_next
        print(f"Raised in {tb.tb_frame.f_code.co_name}() at "
              f"{tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")

except StopEarly:
    print("Execution stopped")

reading predictor from ./data/SST_Jun_1960-2025.nc...
	file exists, reading...
	found X - renaming to lon
	found Y - renaming to lat
	found T - renaming to time
	Dropping redundand dimension of size 1: zlev
	Found units: Celsius_scale
	Netcdf file covers period of: 1960-06-01 to 2025-06-01
done

Further processing predictor data...
done

Reading predictand from ./data/pr_mon_chirps-v2.0_198101-202308.nc...
	file exists, reading...
	Netcdf file covers period of: 1981-01-01 to 2023-08-01
done

Resampling to seasonal...
done

Aggregating data to zones read from data/Botswana.geojson ...
aggregating...
	Average values for 4 regions derived from data for 84 by 81 grid
Aligning predictor and predictand data...
Setting up directories to write to...
done
Calculating observed terciles...
Calculating cross-validated hindcast...
Calculating deteriministic forecast...
Calculating probabilistic hindcast and forecast using error variance...
(4, 1) (42, 4)
Hindcast and forecast calculated.
Calculatin

In [46]:
#to do
# include predictor code in output directory structure
# download function
# data saving
# extend and clean up plotting
# test data ingestion errors
# GUI
# multimodel forecast