In [1]:
import argparse
import traceback
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
import models
import tasks
import utils.callbacks
import utils.data
import utils.email
import utils.logging
import torch

In [None]:
DATA_PATHS = {
    "shenzhen": {"feat": "data/sz_speed.csv", "adj": "data/sz_adj.csv"},
    "losloop": {"feat": "data/los_speed.csv", "adj": "data/los_adj.csv"},
}


def get_model(args, dm):
    model = None
    if args.model_name == "GCN":
        model = models.GCN(adj=dm.adj, input_dim=args.seq_len, output_dim=args.hidden_dim)
    if args.model_name == "GRU":
        model = models.GRU(input_dim=dm.adj.shape[0], hidden_dim=args.hidden_dim)
    if args.model_name == "TGCN":
        model = models.TGCN(adj=dm.adj, hidden_dim=args.hidden_dim)
    return model


def get_task(args, model, dm):
    task = getattr(tasks, args.settings.capitalize() + "ForecastTask")(
        model=model, feat_max_val=dm.feat_max_val, **vars(args)
    )
    return task


def get_callbacks(args):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="train_loss")
    plot_validation_predictions_callback = utils.callbacks.PlotValidationPredictionsCallback(monitor="train_loss")
    callbacks = [
        checkpoint_callback,
        plot_validation_predictions_callback,
    ]
    return callbacks


def main_supervised(args):
    dm = utils.data.SpatioTemporalCSVDataModule(
        feat_path=DATA_PATHS[args.data]["feat"], adj_path=DATA_PATHS[args.data]["adj"], **vars(args)
    )
    model = get_model(args, dm)
    task = get_task(args, model, dm)
    callbacks = get_callbacks(args)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(task, dm)
    results = trainer.validate(datamodule=dm)
    return results, model,task, dm  


def main(args):
    rank_zero_info(vars(args))
    results, model,task, dm = globals()["main_" + args.settings](args)
    return results, model,task, dm

In [2]:
parser = argparse.ArgumentParser()
#parser = pl.Trainer.add_argparse_args(parser)

parser.add_argument("--max_epochs", default = 100)
parser.add_argument("--pre_len", type=int, default=3)
parser.add_argument("--val_batch_size", type=int, default=1)
#parser.add_argument("--learning_rate", default = 0.001)
#parser.add_argument("--weight_decay", default = 0)
#parser.add_argument("--batch_size", default =  32)
#parser.add_argument("--hidden_dim", default = 64)
#parser.add_argument("--loss", default = "mse_with_regularizer") 
parser.add_argument("--gpus", default =  1) 


parser.add_argument(
    "--data", type=str, help="The name of the dataset", choices=("shenzhen", "losloop"), default="shenzhen"
)
parser.add_argument(
    "--model_name",
    type=str,
    help="The name of the model for spatiotemporal prediction",
    choices=("GCN", "GRU", "TGCN"),
    default="TGCN",
)
parser.add_argument(
    "--settings",
    type=str,
    help="The type of settings, e.g. supervised learning",
    choices=("supervised",),
    default="supervised",
)
parser.add_argument("--log_path", type=str, default=None, help="Path to the output console log file")
parser.add_argument("--send_email", "--email", action="store_true", help="Send email when finished")

temp_args, _ = parser.parse_known_args()

parser = getattr(utils.data, temp_args.settings.capitalize() + "DataModule").add_data_specific_arguments(parser)
parser = getattr(models, temp_args.model_name).add_model_specific_arguments(parser)
parser = getattr(tasks, temp_args.settings.capitalize() + "ForecastTask").add_task_specific_arguments(parser)

args = parser.parse_args("")
utils.logging.format_logger(pl._logger)
if args.log_path is not None:
    utils.logging.output_logger_to_file(pl._logger, args.log_path)


In [10]:
DATA_PATHS = {
    "shenzhen": {"feat": "data/sz_speed.csv", "adj": "data/sz_adj.csv"},
    "losloop": {"feat": "data/los_speed.csv", "adj": "data/los_adj.csv"},
}


def get_model(args, dm):
    model = None
    if args.model_name == "GCN":
        model = models.GCN(adj=dm.adj, input_dim=args.seq_len, output_dim=args.hidden_dim)
    if args.model_name == "GRU":
        model = models.GRU(input_dim=dm.adj.shape[0], hidden_dim=args.hidden_dim)
    if args.model_name == "TGCN":
        model = models.TGCN(adj=dm.adj, hidden_dim=args.hidden_dim)
    return model

dm = utils.data.SpatioTemporalCSVDataModule(
    feat_path=DATA_PATHS[args.data]["feat"], adj_path=DATA_PATHS[args.data]["adj"], **vars(args)
)
model = get_model(args, dm)

In [11]:
model.load_state_dict(torch.load("pretrained/tgcn_100.pth"))
model.eval()

TGCN(
  (tgcn_cell): TGCNCell(
    (graph_conv1): TGCNGraphConvolution()
    (graph_conv2): TGCNGraphConvolution()
  )
)

In [8]:
#Clone regressor
#source_task: contain attributes model (TGCN) and regressor
#dest_task: clone only regressor part. 

dest_task = tasks.supervised.SupervisedForecastTask_clone(hidden_dim = args.hidden_dim)
source_task = task

task_clone_layers = []
for item in dest_task._modules.items():
    task_clone_layers.append(item[0])
    
source_task.eval()
dest_task.eval()
for layer in task_clone_layers: 
    if layer == "regressor":
        if hasattr(getattr(dest_task, layer), 'weight'):
            with torch.no_grad():
                getattr(dest_task, layer).weight.copy_(getattr(source_task, layer).weight)
        if hasattr(getattr(dest_task, layer), 'bias'):
            with torch.no_grad():
                getattr(dest_task, layer).bias.copy_(getattr(source_task, layer).bias)
            

In [12]:
adj = dm._adj # adjacent matrix for GCN
input_dim =  adj.shape[0] 
hidden_dim = args.hidden_dim # RNN hidden state dimension

source_model = source_task.model.tgcn_cell # source model we need to clone 
source_model.eval()

# Taking the sample and the model out
probe_sample = 10
sample = 0
for val_features, val_labels in dm.val_dataloader():
    
    if sample > probe_sample:
        break
    
    if sample == probe_sample:

        inputs = val_features
        batch_size, seq_len, num_nodes = inputs.shape
        assert input_dim == num_nodes
            
    sample += 1

AttributeError: 'SpatioTemporalCSVDataModule' object has no attribute 'val_dataset'

In [16]:
dm.val_dataloader()

AttributeError: 'SpatioTemporalCSVDataModule' object has no attribute 'val_dataset'