In [None]:
import json
import os
import pickle
import sys

import networkx as nx
import numpy as np
import pandas as pd
import torch

# Make the repository root importable before pulling in project-local modules.
sys.path.append("../../..")
from pipelines.utils import DATASET_NAME, ROOT_DIR, TRAIN_CONFIG_KEYS, load_road_network
from road_emb import GAEModel, GATEncoder, Node2VecModel
from utils import create_pyg_data

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch and NumPy so that re-running the training cells below is repeatable
# (assuming the models draw their randomness from these generators — confirm).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

city = "porto"
# load_road_network returns four values; only edge_df (feature columns below)
# and LG (passed to create_pyg_data) are used in this notebook.
edge_df, _, _, LG = load_road_network(city_name=city)

In [None]:
#### Features

# Input columns taken from edge_df: highway_enc, lanes, maxspeed, length,
# avg_speed and util. calc_traveltime=True asks create_pyg_data to also
# derive a travel-time feature.
feature_columns = ["highway_enc", "lanes", "maxspeed", "length", "avg_speed", "util"]
df = edge_df.loc[:, feature_columns]
data = create_pyg_data(LG, df, calc_traveltime=True)
data

In [None]:
# node2vec embeddings; p/q are presumably the usual node2vec walk biases
# (return / in-out) — confirm against Node2VecModel.
node2vec = Node2VecModel(data, p=1, q=4, device=device)
node2vec.train(epochs=200)

In [None]:
# GAE model with a GAT encoder and 128-dimensional embeddings.
gae_model = GAEModel(data, emb_dim=128, encoder=GATEncoder, device=device)
gae_model.train(epochs=10_000)

In [None]:
# Fetch the learned embeddings from both trained models (GAE first, then node2vec).
gae_embs, node2vec_embs = gae_model.load_emb(), node2vec.load_emb()

In [None]:
# Combine the embeddings loaded in the previous cell (no need to reload them here).
# Concatenating along axis=1 puts each row's GAE features first, then its node2vec features,
# so both matrices must have the same number of rows.
# NOTE(review): np.concatenate assumes load_emb() returns CPU arrays/tensors NumPy can convert — confirm.
assert gae_embs.shape[0] == node2vec_embs.shape[0], (
    f"Row mismatch: GAE {tuple(gae_embs.shape)} vs node2vec {tuple(node2vec_embs.shape)}"
)
embs = np.concatenate((gae_embs, node2vec_embs), axis=1)

In [None]:
# Save the combined GAE + node2vec embeddings as a pickle.
embs_file = f"{city}_road_embs2.pkl"
embs_dir = os.path.join(ROOT_DIR, "models/proposed/road_cell_embeddings")
# open() raises FileNotFoundError if the target directory does not exist yet.
os.makedirs(embs_dir, exist_ok=True)
embs_file = os.path.join(embs_dir, embs_file)
with open(embs_file, 'wb') as fh:
    pickle.dump(embs, fh, protocol=pickle.HIGHEST_PROTOCOL)
# Report only after the with-block has flushed and closed the file.
print("Saved to: ", embs_file)

In [None]:
# Display the path the combined embeddings were written to.
embs_file

In [None]:
# Shape of the concatenated embedding matrix (GAE columns followed by node2vec columns).
embs.shape

In [None]:
# Open question: how are these embeddings connected?

### TOAST baseline ###

In [None]:
from models.baselines.toast import SkipGramToast

In [None]:
# The number of rows here is used as the TOAST input size below.
edge_df.shape

In [None]:
# TOAST input size: one entry per row of edge_df (presumably one per road segment — confirm).
input_size = edge_df.shape[0]
# Checkpoint path relative to ROOT_DIR, derived from `city` instead of hard-coding "porto".
path = f"models/states/{city}/toast_segments_{city}.pt"

In [None]:
# Build the TOAST skip-gram encoder and load its saved state from ROOT_DIR/path.
# The embedding size 128 presumably mirrors config["road_emb_size"] used at training time.
_road_encoder = SkipGramToast(input_size, 128)
_road_encoder.load_model(path=os.path.join(ROOT_DIR, path))

In [None]:
# Extract the embeddings from the loaded TOAST encoder.
_road_emb = _road_encoder.load_emb()

In [None]:
# Shape of the TOAST embedding matrix.
_road_emb.shape

In [None]:
# Save the TOAST baseline embeddings as a pickle.
embs_file = f"{city}_road_embs_toast.pkl"
embs_dir = os.path.join(ROOT_DIR, "models/proposed/road_cell_embeddings")
# open() raises FileNotFoundError if the target directory does not exist yet.
os.makedirs(embs_dir, exist_ok=True)
embs_file = os.path.join(embs_dir, embs_file)
with open(embs_file, 'wb') as fh:
    pickle.dump(_road_emb, fh, protocol=pickle.HIGHEST_PROTOCOL)
# Report only after the with-block has flushed and closed the file.
print("Saved to: ", embs_file)