In [1]:
import torch
import torch.nn as nn
from enso.hyperparams_and_args import get_argparser
from ninolearn.IO.read_processed import data_reader

from MTGNN_multi_channel.data_handler_single_index_CNN_data import IndexLoader
from MTGNN.train_single_step import main, evaluate
from enso.utils import read_ssta
from enso.plotting import plot_time_series, heatmap_of_edges
import warnings
# Report each distinct warning only the first time it is raised, so repeated
# warnings do not flood the cell output.
warnings.filterwarnings(action='once')

In [2]:
# Notebook-wide switches.
TRANSFER, ONLY_PLOT_EDGES = True, True

# Root folder of the input data (machine-specific absolute path).
data_dir = 'C:/Users/salva/OneDrive/Documentos/Projects/ProjectX/Data/'

# Destination folder for saved plots; None means "do not save them".
# To save, point this at a folder instead, e.g.
#   "C:/Users/salva/OneDrive/Documentos/Projects/ProjectX/Plots/edges/E33"
save_plots_to = None

  and should_run_async(code)


In [3]:
# (start, end) month strings of the three data splits; train_exp1 reads these
# notebook-level names when it calls main().
train_dates = ("1871-01", "1972-12")
val_dates = ("1973-01", "1983-12")
test_dates = ("1984-01", "2020-08")

# Start from the defaults of the "ERSSTv5" experiment. An explicit empty
# argument list is parsed so that nothing is read from the kernel's own
# command line (assumes get_argparser returns an argparse parser -- TODO confirm).
parser = get_argparser(experiment="ERSSTv5")
args = parser.parse_args([])
args.data_dir = data_dir

# Longitude / latitude bounds of the inputs.
args.lon_min, args.lon_max = 190, 240
args.lat_min, args.lat_max = -5, 5

# Model and training settings.
args.window = 3
args.layers = 2
args.prelu = True
args.epochs = 50

In [5]:
def train_exp1(args, reader):
    """Train a model for the configuration in ``args`` without a predefined graph.

    Prints a one-line summary of the setup, stores the mask returned by
    ``read_ssta`` on ``args.mask`` and then calls ``main`` with ``adj=None``.

    Args:
        args: Settings namespace. Reads ``horizon``, ``layers``, ``lat_min``,
            ``lat_max``, ``lon_min``, ``lon_max``, ``window``, ``index``,
            ``resolution`` and ``save``; ``mask`` is set as a side effect.
        reader: Data reader object forwarded to ``read_ssta``.

    Note:
        The split ranges come from the notebook-level ``train_dates``,
        ``val_dates`` and ``test_dates`` variables, so the cell defining them
        must have been executed first.
    """
    summary = (
        f"#Lead months = {args.horizon}, #layers = {args.layers}, "
        f"Inputs between {args.lat_min} to {args.lat_max} latitude "
        f"and {args.lon_min} to {args.lon_max} longitude "
        f"for the past {args.window} months"
    )
    print(summary)

    # read_ssta returns a pair when get_mask=True; only the mask is needed here.
    _, mask = read_ssta(
        index=args.index,
        get_mask=True,
        resolution=args.resolution,
        reader=reader,
    )
    args.mask = mask

    print("Model will be saved in", args.save)
    # adj=None: no adjacency matrix is supplied, the edges are learned.
    main(
        args,
        adj=None,
        train_dates=train_dates,
        val_dates=val_dates,
        test_dates=test_dates,
    )

In [6]:
def eval_model(args, test_set="ERSSTv5", start_date="1984-01", end_date="2020-08"):
    """Load the saved model and report its scores on one test data set.

    Args:
        args: Settings namespace. Reads ``save`` (checkpoint path) and
            ``device``; it is also forwarded to ``IndexLoader`` and ``evaluate``.
        test_set: Name of the data set to evaluate on, passed to ``IndexLoader``.
        start_date: First month of the evaluation period ("YYYY-MM").
        end_date: Last month of the evaluation period ("YYYY-MM").

    Returns:
        Tuple ``(preds, Ytrue, time_steps)``: the predictions and targets
        returned by ``evaluate`` (with ``return_oni_preds=True``) and the
        loader's ``semantic_time_steps``.

    Note:
        The data folder is taken from the notebook-level ``data_dir`` variable.
        ``torch.load`` unpickles the file, so only load checkpoints you trust.
        NOTE(review): recent PyTorch releases default to ``weights_only=True``,
        which rejects fully pickled models -- confirm against the installed version.
    """
    # Remap the checkpoint onto the requested device so that a model trained on
    # another device (e.g. a GPU) can still be evaluated here.
    model = torch.load(args.save, map_location=args.device)
    model.eval()

    mse_criterion = nn.MSELoss().to(args.device)
    l1_criterion = nn.L1Loss().to(args.device)

    Data = IndexLoader(args, test_set=test_set, start_date=start_date, end_date=end_date, data_dir=data_dir)
    test_acc, test_rae, test_corr, oni_test_stats, preds, Ytrue = evaluate(
        Data, Data.test[0], Data.test[1],
        model, mse_criterion, l1_criterion, args,
        return_oni_preds=True,
    )
    # NOTE(review): the first two values are printed under the labels "rse" and
    # "RMSE" although they are named test_acc / test_rae -- confirm against evaluate().
    print(test_set, " test stats... OVERALL: rse {:5.4f} , RMSE {:5.4f} , corr {:5.4f}"
                    " | ONI:  RMSE {:5.4f} , corr {:5.4f}"
          .format(test_acc, test_rae, test_corr, oni_test_stats["RMSE"], oni_test_stats["Corrcoef"]))
    return preds, Ytrue, Data.semantic_time_steps

In [7]:
def train_and_eval(args):
    """Train a model, score it on both test sets and plot the results.

    Steps: build a data reader for the configured period and lat/lon box, train
    via ``train_exp1``, evaluate on "ERSSTv5" and "GODAS" (scores are printed by
    ``eval_model``), draw the heatmap of the learned edges and finally plot the
    ERSSTv5 targets against the forecasts.

    Args:
        args: Settings namespace. Reads ``start_date``, ``end_date``,
            ``lon_min``, ``lon_max``, ``lat_min``, ``lat_max``, ``index``,
            ``save`` and ``data_dir`` here, and is forwarded to the helpers.
    """
    reader = data_reader(
        startdate=args.start_date, enddate=args.end_date,
        lon_min=args.lon_min, lon_max=args.lon_max,
        lat_min=args.lat_min, lat_max=args.lat_max,
    )

    train_exp1(args, reader=reader)

    # Test scores: only the ERSSTv5 predictions are kept for plotting, the
    # GODAS run is executed for its printed statistics.
    preds, Y, time_axis = eval_model(args, "ERSSTv5")
    eval_model(args, "GODAS")

    # Plot learned edges. The checkpoint path is args.save, the same file that
    # eval_model loads; the previously used name `file_path` was not defined
    # anywhere in this notebook and raised a NameError on a fresh kernel.
    # NOTE(review): assumes heatmap_of_edges expects the model file path as its
    # first argument -- confirm against enso.plotting.
    heatmap_of_edges(args.save, args=args, reader=reader, index=args.index,
                     min_weight=1e-4, data_dir=args.data_dir)

    plot_time_series(Y, preds, time_steps=time_axis, labels=["ERSSTv5 ONI", "GNN Forecast"],
                     ylabel=f"{args.index} index", save_to=None)


In [None]:
# Single run: overrides the epoch count set in the configuration cell and
# forecasts with a horizon of 1.
args.epochs = 25
args.horizon = 1
train_and_eval(args)

#Lead months = 1, #layers = 2, Inputs between -5 to 5 latitude and 190 to 240 longitude for the past 3 months
Model will be saved in ../model/
Time series Length = 1222, Number of nodes = 33, Predict 1 time steps in advance using 3 time steps, Training set size = 1219  

The receptive field size is 19
Number of model parameters is 27969
begin training
--> Epoch   1 | time:  5.61s | train_loss 0.5047 | Val. loss 0.0356, corr 0.8565 | ONI corr 0.8606, RMSE 0.5125
Model will be saved...
--> Epoch   2 | time:  3.82s | train_loss 0.2103 | Val. loss 0.0268, corr 0.8916 | ONI corr 0.8941, RMSE 0.3738
Model will be saved...
--> Epoch   3 | time:  3.82s | train_loss 0.1282 | Val. loss 0.0208, corr 0.9420 | ONI corr 0.9468, RMSE 0.2824
Model will be saved...
--> Epoch   4 | time:  3.92s | train_loss 0.0833 | Val. loss 0.0166, corr 0.9635 | ONI corr 0.9704, RMSE 0.2170
Model will be saved...
--> Epoch   5 | time:  4.05s | train_loss 0.0645 | Val. loss 0.0147, corr 0.9708 | ONI corr 0.9778, RMSE 0