In [None]:
'''
This notebook uses a GRETEL model for route prediction with or without target (destination) information.
- NOTE: Prediction and evaluation are performed in two separate notebooks. This notebook is for prediction and needs to run in a PyTorch Geometric environment (env_pyg).
- The input for model training needs to be formatted in a specific way. The notebook DATA_preprocess_for_GRETEL.ipynb takes care of that.
- GRETEL uses a specially formatted config file where the user can specify model hyperparameters and file paths to the training data (more details in the train method of the GretelPathPrediction class).
- Specify the test data used to evaluate the prediction model.
- Specify the parameters for prediction.
The notebook will train the prediction model (or load a pretrained one), make predictions, and save the results to file.
'''

In [None]:
import pandas as pd
import json
import numpy as np
import pickle
from ast import literal_eval
import sys
import time

sys.path.append('../datawrangling')

# import modules
import dataloader_paths
from Gretel_path_prediction import GretelPathPrediction

import warnings
warnings.filterwarnings('ignore')

In [None]:
# ---- Model ----
# True  -> unpickle a pretrained model; False -> train a new model in this notebook.
load_model = False
# Metadata is read from, and predictions/metadata are written to, <directory><path_format>/
directory = '../../models/gretel_prediction_models/oslo_passenger/'
# Training-data format: 'node2node' (recommended) or 'start2target'.
path_format = 'node2node'

# ---- Test data ----
test_dates = ['202209']
# Sampling window and stride passed to dataloader_paths.sample_path_data as (start, end, step).
selection_start, selection_end, selection_step = 0, -1, 20

# ---- Prediction ----
# 'next_nodes' predicts without destination information; 'path' predicts with it.
prediction_task = 'next_nodes'
n_walks = 1000           # random walks drawn when sampling predictions
n_start_nodes = 1        # how many nodes of each test path are observed
n_steps_vals = [10]      # prediction horizons to evaluate (only relevant for 'next_nodes')
n_predictions = 1        # keep the top-n_predictions predictions
max_path_length = 150    # upper bound on predicted path length (number of subsequent nodes)

In [None]:
# Either load a pretrained model from file...
if load_model:
    network_name = '202204_waypoints_DP30_HDBSCAN25_stavanger_full_UTM'
    model_path = '../../models/gretel_prediction_models/trained_models/'+network_name+'_target.obj'
    # NOTE: pickle.load can execute arbitrary code -- only load model files you created or trust.
    # `with` guarantees the file is closed even if unpickling fails.
    with open(model_path, 'rb') as fileObj:
        model = pickle.load(fileObj)
    with open('../../models/gretel_prediction_models/trained_models/metadata_stavanger.json', 'r') as json_file:
        meta_dict = json.load(json_file)
    data_version = meta_dict['data_version']
    print(meta_dict)
    # `filter` must be defined in this branch as well: otherwise the test-data loader
    # later receives Python's builtin `filter` function instead of the dataset filter.
    filter = meta_dict['filter']

# ... or train a model from scratch
else:
    config_file = 'route_target'   # GRETEL config passed to model.train
    task = 'path'
    # load metadata written alongside the preprocessed training data
    with open(directory+path_format+'/metadata.json', 'r') as json_file:
        meta_dict = json.load(json_file)
    network_name = meta_dict['network_name']
    data_version = meta_dict['data_version']
    filter = meta_dict['filter']
    # train model
    model = GretelPathPrediction()
    model.train(config_file, directory, task)

In [None]:
# plot training metrics recorded by the model
# NOTE(review): test_only=True is forwarded to GretelPathPrediction.plot_train_test_metrics;
# presumably it restricts the plot to test-set metrics -- TODO confirm against the class.
model.plot_train_test_metrics(test_only=True)

In [None]:
# Load all test paths for the chosen dates. The (start, end, step) window below is
# presumably "everything" (same convention as selection_*) -- TODO confirm; per-horizon
# sampling with selection_start/end/step is applied later, in the prediction loop.
path_prefix = '../../data/paths/'
_all_paths_window = (0, -1, 1)
all_test_paths = dataloader_paths.load_path_test_data(
    path_prefix,
    network_name,
    test_dates,
    *_all_paths_window,
    filter=filter,
    data_version=data_version,
)

In [None]:
for n_steps in n_steps_vals:
    # --- sample test data ---
    if prediction_task == 'next_nodes':
        # cut the test paths into sub-paths (presumably of length n_start_nodes + n_steps:
        # observed prefix + prediction horizon), then sample from those
        sub_paths = dataloader_paths.split_path_data(all_test_paths, n_steps+n_start_nodes)
        test_paths = dataloader_paths.sample_path_data(sub_paths, selection_start, selection_end, selection_step)
    else:
        # 'path' prediction samples the unsplit test paths; n_steps is not used here
        test_paths = dataloader_paths.sample_path_data(all_test_paths, selection_start, selection_end, selection_step)
    n_test_paths = len(test_paths)

    #### MAKE PREDICTIONS ####
    # perf_counter() is monotonic and high-resolution; time.time() can jump with clock changes
    start_time = time.perf_counter()
    predictions = model.predict(prediction_task, test_paths, n_start_nodes, n_steps,
                                n_predictions, n_walks, max_path_length)
    end_time = time.perf_counter()  # end timer
    elapsed = end_time - start_time
    print(f'Time elapsed: {elapsed/60:.2f} minutes')
    pps = n_test_paths/elapsed
    print('Predictions per second: ', pps)

    # output files are tagged with task and horizon, e.g. predictions_next_nodes10.csv
    out_prefix = directory+path_format
    run_tag = prediction_task+str(n_steps)

    # save results as csv
    predictions.to_csv(out_prefix+'/predictions_'+run_tag+'.csv')

    # save metadata to file
    if not load_model:
        # model.config is only read for a model trained in this notebook
        meta_dict['lr'] = model.config.lr
        meta_dict['loss'] = model.config.loss
        meta_dict['n_epochs'] = model.config.number_epoch
        meta_dict['target_prediction'] = model.config.target_prediction
    meta_dict['n_walks'] = n_walks
    meta_dict['n_start_nodes'] = n_start_nodes
    meta_dict['n_steps'] = n_steps
    meta_dict['prediction_task'] = prediction_task
    meta_dict['predictions_per_second'] = pps
    meta_dict['model_type'] = 'Gretel'
    meta_dict.update({'test_dates':str(test_dates),
                      'selection_start':selection_start,
                      'selection_end':selection_end,
                      'selection_step':selection_step,
                      'n_test_paths':n_test_paths})
    with open(out_prefix+'/metadata_'+run_tag+'.json', 'w') as json_file:
        json.dump(meta_dict, json_file)

In [None]:
# Optionally save the trained model as a pickle object (disabled by default).
# This replaces a quoted-out snippet whose filename expression was missing a `+`.
save_model = False
if save_model:
    # NOTE(review): assumes `filter` (from the metadata) is a string -- TODO confirm.
    model_out_path = ('../../models/gretel_prediction_models/trained_models/'
                      + network_name + filter + '.obj')
    with open(model_out_path, 'wb') as model_file:
        pickle.dump(model, model_file)