In [36]:
from iopath.common.file_io import g_pathmgr as pathmgr

import torch
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from functools import partial
import matplotlib.pyplot as plt

import model_finetune
from gmae_st.data.get_dataset import get_dataset
from gmae_st.utils import misc
from gmae_st.data.utils import DX_DICT
from data.utils import collator

In [37]:
def get_vis_dataset(
        dataset_type,
        dataset_name,
        dataset_dir,
        graph_token,
        n_hist,
        n_pred,
        norm
):
    """Create the test-split DataLoader and a summary of graph dimensions.

    The arguments are forwarded to ``get_dataset`` (with ``task='pred'`` and
    ``mode='test'``); ``dataset_dir`` is passed as ``data_dir``.

    Returns:
        tuple: ``(data_loader_test, graph_info)`` where ``data_loader_test``
        iterates the test split sequentially with ``batch_size=1`` and
        ``graph_info`` is a dict with the keys ``node_feature_dim``,
        ``num_nodes``, ``num_edges``, ``num_spatial``, ``num_in_degree`` and
        ``num_out_degree``, all derived from the first training sample.
    """
    splits = get_dataset(
        task='pred',
        mode='test',
        dataset_type=dataset_type,
        dataset_name=dataset_name,
        data_dir=dataset_dir,
        graph_token=graph_token,
        n_hist=n_hist,
        n_pred=n_pred,
        norm=norm
    )
    train_split = splits['train_dataset']
    test_split = splits['test_dataset']

    # Every size below is read from a single training sample.
    first_sample = train_split[0]

    def one_past_max(key):
        # The original note says the +1 accounts for data.utils.collator changes.
        return torch.max(first_sample[key]).item() + 1

    graph_info = dict(
        node_feature_dim=first_sample['x'].shape[-1],
        num_nodes=first_sample['adj'].shape[0],
        num_edges=len(first_sample['edge_attr']),
        num_spatial=one_past_max('spatial_pos'),
        num_in_degree=one_past_max('in_degree'),
        num_out_degree=one_past_max('out_degree'),
    )

    # The collator receives the scaler of the *training* split.
    collate = partial(
        collator,
        max_node=graph_info['num_nodes'],
        spatial_pos_max=graph_info['num_spatial'],
        graph_token=graph_token,
        scaler=train_split.scaler,
    )
    test_loader = torch.utils.data.DataLoader(
        test_split,
        sampler=torch.utils.data.SequentialSampler(test_split),
        batch_size=1,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate,
    )
    return test_loader, graph_info


def load_model(model_fp, model):
    """Load checkpoint weights from ``model_fp`` into ``model`` and return it.

    The checkpoint is opened through ``pathmgr`` and deserialised onto the
    CPU. The state dict is taken from the ``"model"`` entry when present,
    otherwise from ``"model_state"`` (a ``KeyError`` is raised if neither
    exists). Loading is non-strict; the result of ``load_state_dict`` is
    printed so missing/unexpected keys are visible.
    """
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    with pathmgr.open(model_fp, 'rb') as handle:
        checkpoint = torch.load(handle, map_location='cpu')

    state_key = "model" if "model" in checkpoint else "model_state"
    load_report = model.load_state_dict(checkpoint[state_key], strict=False)
    print(load_report)
    return model

In [38]:
def find_node_with_most_variance(dataloader):
    """Find the node whose values vary the most within any single sample.

    Each item yielded by ``dataloader`` is unpacked as ``(data, target)``;
    ``data`` is treated as ``[batch_size, time_steps, nodes, features]`` and
    the variance is taken jointly over the time-step and feature axes.

    NOTE(review): the loader built elsewhere in this notebook appears to yield
    dict batches — confirm this function is fed ``(data, target)`` pairs.

    Args:
        dataloader: Iterable of ``(data, target)`` pairs; ``target`` is unused.

    Returns:
        tuple: ``(node_with_most_variance, max_variance)``. The first element
        is the ``[time_steps, features]`` slice of the winning node taken from
        the sample in which it was found, or ``None`` when the loader is empty
        or no variance exceeds zero. The second element is that variance
        (a 0-dim tensor, or the initial ``0`` when nothing was found).
    """
    max_variance = 0
    node_with_most_variance = None

    for data, _target in dataloader:
        # Variance over time steps (dim 1) and features (dim 3) -> [batch_size, nodes].
        variances = torch.var(data, dim=(1, 3))

        # Best node per sample: values and node indices, each of shape [batch_size].
        per_sample_max, per_sample_node = torch.max(variances, dim=1)

        # Best sample inside this batch.
        batch_max_variance, best_sample = torch.max(per_sample_max, dim=0)

        if batch_max_variance > max_variance:
            max_variance = batch_max_variance
            sample_idx = int(best_sample)
            # Fix: the sample argmax used to index `data` as if it were the
            # node index, returning a whole sample instead of the node's data.
            node_idx = int(per_sample_node[sample_idx])
            node_with_most_variance = data[sample_idx, :, node_idx, :]

    return node_with_most_variance, max_variance

In [68]:
def get_preds(
        model,
        args,
        data_loader,
        max_batches=3,
):
    """Run ``model`` over ``data_loader`` on CPU and gather per-node series.

    For every batch the 4-D target ``batch['y']`` (unpacked as ``N, P, V, D``)
    and the model output are flattened to ``(N * P, V)``; column ``v`` is then
    appended to the running series of node ``v``.

    Args:
        model: Module called as ``model(batch)``; put into eval mode here.
        args: Unused; kept for interface compatibility.
        data_loader: Iterable of batches accepted by ``misc.prepare_batch``.
        max_batches: Maximum number of batches to process. Defaults to 3
            (the previously hard-coded limit); ``None`` processes everything.

    Returns:
        dict: ``{node: {'pred': Tensor, 'target': Tensor}}`` with 1-D tensors
        concatenated over the processed batches. Empty if nothing was processed.
    """
    if max_batches is not None and max_batches <= 0:
        return dict()

    device = torch.device('cpu')
    # switch to evaluation mode
    model.eval()
    task = 'pred'
    collected = dict()

    for i, batch in enumerate(data_loader):
        batch = misc.prepare_batch(batch, device=device)
        scaler = batch['scaler'] if 'scaler' in batch else None

        targets = batch['y']
        # NOTE(review): predictions are inverse-transformed below but targets
        # are not (the original code had that step commented out) — confirm
        # that targets already arrive in the original scale.
        if task == 'pred':
            N, P, V, D = targets.shape
            if D == 2:
                # Keep only the first feature channel of the target.
                targets = targets[..., [0]]

        # Inference only: no_grad keeps autograd graphs out of the returned tensors.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            preds = model(batch)
            if scaler:
                outputs = scaler.inverse_transform(preds)
            else:
                outputs = preds

        # Fix: a typo ("ouptuts") left `outputs` unreshaped, so the per-node
        # slice below indexed the wrong axis. reshape also accepts
        # non-contiguous tensors, unlike view.
        outputs = outputs.reshape(N * P, V)
        targets = targets.reshape(N * P, V)

        for node in range(batch['x'].shape[2]):
            chunks = collected.setdefault(node, {'pred': [], 'target': []})
            chunks['pred'].append(outputs[:, node])
            chunks['target'].append(targets[:, node])

        if max_batches is not None and i + 1 >= max_batches:
            break

    # Concatenate once per node instead of re-concatenating on every batch.
    return {
        node: {
            'pred': torch.cat(chunks['pred'], dim=0),
            'target': torch.cat(chunks['target'], dim=0),
        }
        for node, chunks in collected.items()
    }


def visualize_pred(
        pred,
        dataset_args,
        dataset_loader,
        time_step_idx,
        node_idx_list,
):
    """Plot predictions against ground truth, one subplot per requested node.

    Args:
        pred: Mapping ``node -> {'pred': Tensor, 'target': Tensor}``.
        dataset_args: Unused; kept for interface compatibility.
        dataset_loader: Unused; kept for interface compatibility.
        time_step_idx: Index (typically a list of ints) applied to the first
            axis of each node's ``'pred'`` / ``'target'`` tensor.
        node_idx_list: Keys of ``pred`` to plot, one subplot each.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    if len(node_idx_list) == 0:
        # Nothing to draw; plt.subplots cannot build a zero-row grid.
        return

    # A tuple would be read by torch as one index per dimension, so turn
    # tuple/range selections into a list for row selection.
    row_index = list(time_step_idx) if isinstance(time_step_idx, (tuple, range)) else time_step_idx

    # squeeze=False always returns a 2-D array of Axes. Without it a single
    # node yields a bare Axes object and `ax[i]` fails.
    fig, axes = plt.subplots(
        len(node_idx_list), 1, figsize=(20, 12), sharex=True, squeeze=False
    )
    axes = axes[:, 0]
    pred_color, target_color = "#00798c", "#d1495b"

    for axis, node_idx in zip(axes, node_idx_list):
        node_preds = pred[node_idx]['pred'][row_index].detach().cpu().numpy()
        node_targets = pred[node_idx]['target'][row_index].detach().cpu().numpy()

        # x positions are simply 0..k-1 for the k selected entries.
        time_steps = list(range(node_preds.shape[0]))

        axis.plot(time_steps, node_preds, color=pred_color, label='Predictions')
        axis.plot(time_steps, node_targets, color=target_color, label='Ground Truth')
        axis.set_title(f'Predictions vs Ground Truth for Node {node_idx} at Time Step {time_step_idx}', fontsize=16)
        # Pad the y-range by one unit on each side of the plotted values.
        axis.set_ylim(
            min(node_preds.min(), node_targets.min()) - 1,
            max(node_preds.max(), node_targets.max()) + 1,
        )
        axis.legend(loc='upper left')

        # Remove top and right spines
        for side in ('top', 'right'):
            axis.spines[side].set_visible(False)

    # Shared x-label on the bottom subplot.
    axes[-1].set_xlabel('Time Steps', fontsize=14)

    # Adjust layout to prevent overlap
    fig.tight_layout()
    plt.show()
    return

In [69]:
def vis_pipeline(
        model_fp,
        dataset_args=None,
        time_step_idx=None,
        node_idx=None,
        seed=0,
):
    """Load a checkpoint, run it on the test loader and plot its predictions.

    Args:
        model_fp: Path of the checkpoint handed to ``load_model``.
        dataset_args: Keyword arguments for ``get_vis_dataset``. When falsy, a
            default configuration (``'traffic'`` / ``'metr-la'``, 12-step
            history and horizon, normalisation on, no graph token) is used.
            The dict passed in is not modified.
        time_step_idx: Row selection forwarded to ``visualize_pred``;
            defaults to ``[0, 1, 2, 3, 4, 5]``.
        node_idx: Nodes to plot; defaults to ``[0]``.
        seed: Currently unused.

    Returns:
        The per-node prediction dict produced by ``get_preds``.
    """
    # None sentinels avoid sharing mutable default lists between calls.
    if time_step_idx is None:
        time_step_idx = [0, 1, 2, 3, 4, 5]
    if node_idx is None:
        node_idx = [0]

    if not dataset_args:
        dataset_args = {
            'dataset_type': 'traffic',
            'dataset_dir': '',
            'dataset_name': 'metr-la',
            'n_hist': 12,
            'n_pred': 12,
            'norm': True,
            'graph_token': False,
        }

    loader, graph_info = get_vis_dataset(
        **dataset_args
    )

    # Merge into a fresh dict. Updating `dataset_args` in place leaked the
    # graph_info keys into the caller's dict, so a second call with the same
    # dict passed unexpected keyword arguments to get_vis_dataset.
    model_args = {**dataset_args, **graph_info}
    print(model_args)

    model = model_finetune.graph_causal_pred_mini(
        sep_pos_embed=False,
        cls_token=not model_args['graph_token'],
        end_channel=64,
        **model_args
    )
    model = load_model(
        model_fp=model_fp,
        model=model,
    )

    preds = get_preds(model, model_args, loader)
    visualize_pred(preds, model_args, loader, time_step_idx, node_idx)
    return preds

In [None]:
# Run the visualisation pipeline with its default dataset configuration.
# NOTE(review): the checkpoint path is an absolute, machine-specific location.
preds = vis_pipeline(
    model_fp='/Users/markbai/Documents/gmae_st/metr-la_causal-pred-mini_checkpoint-00099.pth',
)

Using normalization with mean: 54.40592829587617, std: 19.493739270573087
 > metr-la loaded!
{'train_dataset': GraphTemporalDataset(23974), 'valid_dataset': GraphTemporalDataset(3425), 'test_dataset': GraphTemporalDataset(6850)}
 > dataset info ends
{'dataset_type': 'traffic', 'dataset_dir': '', 'dataset_name': 'metr-la', 'n_hist': 12, 'n_pred': 12, 'norm': True, 'graph_token': False, 'node_feature_dim': 2, 'num_nodes': 207, 'num_edges': 22167, 'num_spatial': 12, 'num_in_degree': 155, 'num_out_degree': 155}
model initialized
<All keys matched successfully>


python(80290) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(80294) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(80297) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [50]:
# Show the entry stored under key 0 of `preds`; the cell that assigns `preds`
# must have finished running first, otherwise this raises NameError.
preds[0]

NameError: name 'preds' is not defined