In [1]:
from functools import partial

from iopath.common.file_io import g_pathmgr as pathmgr
import torch
import numpy as np
from sklearn.model_selection import (
    GridSearchCV, PredefinedSplit
)
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.preprocessing import MinMaxScaler
import xgboost as xgb
import matplotlib.pyplot as plt

from statsmodels.tsa.api import VAR

import baselines
import model_finetune
from gmae_st.data.get_dataset import get_dataset
from gmae_st.utils import misc
from gmae_st.data.utils import DX_DICT
from data.utils import collator

In [2]:
class DictArgs:
    """Expose the entries of a dict as attributes (``d['k']`` -> ``obj.k``).

    Lets a plain configuration dict stand in where an argparse-style
    namespace is expected. Values are copied onto the instance once, at
    construction time.
    """

    def __init__(self, d):
        for attr_name in d:
            setattr(self, attr_name, d[attr_name])


def get_vis_dataset(
        dataset_dir,
        graph_token,
        n_hist,
        n_pred,
        num_visits,
        filter_list,
        filter_diagnosis,
        include_pet_volume,
        norm,
        batch_size=4,
        num_workers=8,
):
    """
    Load the ADNI train/valid/test splits and wrap each in a DataLoader.

    Args:
        dataset_dir: directory passed to ``get_dataset`` as ``data_dir``.
        graph_token: forwarded to both ``get_dataset`` and ``collator``.
        n_hist, n_pred, num_visits, filter_list, filter_diagnosis,
        include_pet_volume, norm: forwarded unchanged to ``get_dataset``.
        batch_size: batch size used by all three loaders (default 4).
        num_workers: worker processes per loader (default 8).

    Returns:
        ``(data_loader_train, data_loader_val, data_loader_test, graph_info)``.
        ``graph_info`` holds ``node_feature_dim``, ``num_nodes``,
        ``num_edges``, ``num_spatial``, ``num_in_degree`` and
        ``num_out_degree``, all read from the first training sample.
    """
    dataset_dict = get_dataset(
        dataset_type='brain',
        dataset_name='ADNI',
        data_dir=dataset_dir,
        n_hist=n_hist,
        n_pred=n_pred,
        num_visits=num_visits,
        task='pred',
        filter_list=filter_list,
        filter_diagnosis=filter_diagnosis,
        graph_token=graph_token,
        mode='finetune',
        include_pet_volume=include_pet_volume,
        norm=norm
    )
    dataset_train = dataset_dict['train_dataset']
    dataset_val = dataset_dict['valid_dataset']
    dataset_test = dataset_dict['test_dataset']
    # The training split's scaler is shared by every loader's collate_fn.
    scaler = dataset_train.scaler

    data_sample = dataset_train[0]
    node_feature_dim = data_sample['x'].shape[-1]
    num_nodes = data_sample['adj'].shape[0]
    num_edges = len(data_sample['edge_attr'])

    # account for data.utils.collator changes
    num_spatial = torch.max(data_sample['spatial_pos']).item() + 1
    num_in_degree = torch.max(data_sample['in_degree']).item() + 1
    num_out_degree = torch.max(data_sample['out_degree']).item() + 1
    graph_info = {
        'node_feature_dim': node_feature_dim,
        'num_nodes': num_nodes,
        'num_edges': num_edges,
        'num_spatial': num_spatial,
        'num_in_degree': num_in_degree,
        'num_out_degree': num_out_degree
    }

    collate_fn = partial(
        collator,
        max_node=num_nodes,
        spatial_pos_max=num_spatial,
        graph_token=graph_token,
        scaler=scaler,
    )

    def make_loader(dataset, sampler, drop_last):
        # All three loaders share everything except sampler and drop_last.
        return torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=drop_last,
            collate_fn=collate_fn,
        )

    # Only the training loader drops its incomplete final batch. Validation
    # and test keep every sample; with drop_last=True up to batch_size - 1
    # samples per split would be silently excluded from evaluation.
    data_loader_train = make_loader(
        dataset_train, torch.utils.data.RandomSampler(dataset_train), drop_last=True
    )
    data_loader_val = make_loader(
        dataset_val, torch.utils.data.SequentialSampler(dataset_val), drop_last=False
    )
    data_loader_test = make_loader(
        dataset_test, torch.utils.data.SequentialSampler(dataset_test), drop_last=False
    )
    return data_loader_train, data_loader_val, data_loader_test, graph_info


def prepare_data(data_loader, args):
    """
    Prepares the input and target data for sklearn.

    Each batch from ``data_loader`` is moved to CPU and split into history
    samples and prediction targets with ``misc.get_samples_targets``. Both
    are then flattened so that every sequence contributes exactly one row to
    each array, because sklearn requires X and y to have the same row count.

    Args:
        data_loader: iterable of collated batches (e.g. from ``get_vis_dataset``).
        args: attribute-style config forwarded to ``misc.get_samples_targets``.

    Returns:
        ``(all_samples, all_targets)``: 2-D numpy arrays of shape
        ``(num_sequences, T * V * D)`` and ``(num_sequences, P * V * D)``.
        Here ``(N, P, V, D)`` is the per-batch ``target_shape`` and ``T`` is
        the second axis of ``samples``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    all_samples = []
    all_targets = []
    device = torch.device('cpu')
    for batch in data_loader:
        batch = misc.prepare_batch(batch, device=device)
        samples, targets, target_shape = misc.get_samples_targets(
            batch=batch,
            task='pred',
            device=device,
            args=args
        )

        assert samples.shape[0] == targets.shape[0], \
            f'batch size of samples {samples.shape[0]} does not match targets {targets.shape[0]}'
        assert samples.shape[-1] == targets.shape[-1], \
            f'feature dimension of samples {samples.shape[-1]} does not match targets {targets.shape[-1]}'

        _, T, _, _ = samples.shape
        N, P, V, D = target_shape
        # Fold the history axis T into the feature axis so there is one row
        # per sequence. Reshaping to (N * T, V * D) only lines up with the
        # N target rows when T == 1. reshape (unlike view) also accepts
        # non-contiguous tensors.
        samples = samples.reshape(N, T * V * D).numpy()
        targets = targets.reshape(N, P * V * D).numpy()

        all_samples.append(samples)
        all_targets.append(targets)

    if not all_samples:
        raise ValueError('data_loader yielded no batches; nothing to prepare')

    # Combine all batches into single arrays
    return np.vstack(all_samples), np.vstack(all_targets)

In [3]:
# Default hyper-parameter search spaces, keyed by the model names that
# `grid_search` uses. Each value is a scikit-learn `param_grid`: parameter
# name -> list of candidate values.
PARAM_GRIDS = {
    # Plain LinearRegression exposes nothing worth tuning.
    'Linear Regression': dict(),
    'Random Forest': dict(
        n_estimators=[50, 100, 200],
        max_depth=[None, 10, 20, 30],
        min_samples_split=[2, 5, 10],
    ),
    'Support Vector Regression': dict(
        C=[0.01, 0.1, 1, 10],
        kernel=['linear', 'rbf'],
    ),
    'K-Nearest Neighbors': dict(
        n_neighbors=[3, 5, 7],
        weights=['uniform', 'distance'],
    ),
    'Decision Tree': dict(
        max_depth=[None, 10, 20, 30],
        min_samples_split=[2, 5, 10],
    ),
    'Gradient Boosting': dict(
        n_estimators=[50, 100, 200],
        learning_rate=[0.01, 0.1, 0.5],
        max_depth=[3, 5, 7],
    ),
    'XGBoost': dict(
        n_estimators=[50, 100, 200],
        learning_rate=[0.01, 0.1, 0.5],
        max_depth=[3, 5, 7],
    ),
}


def evaluate_metrics(target, output):
    """
    Compute masked MAE, RMSE and MAPE of ``output`` against ``target``.

    Entries whose target is exactly zero are excluded from all three metrics
    (they would make MAPE divide by zero). Excluded entries count in neither
    the numerator nor the denominator, so each metric is an average over the
    remaining entries only.

    Args:
        target: array-like of ground-truth values.
        output: array-like of predictions, broadcastable against ``target``.

    Returns:
        dict with float values under 'MAE', 'RMSE' and 'MAPE'. MAPE is a
        fraction, not a percentage. All three are NaN when no target entry
        is non-zero.
    """
    target = np.asarray(target, dtype=float)
    output = np.asarray(output, dtype=float)

    error = output - target
    # Broadcast target so it (and the mask) align element-wise with error.
    target = np.broadcast_to(target, error.shape)
    mask = target != 0

    if not mask.any():
        nan = float('nan')
        return {'MAE': nan, 'RMSE': nan, 'MAPE': nan}

    # Select only valid entries; this also avoids dividing by zero below.
    abs_err = np.abs(error[mask])
    return {
        'MAE': float(np.mean(abs_err)),
        'RMSE': float(np.sqrt(np.mean(np.square(abs_err)))),
        'MAPE': float(np.mean(abs_err / np.abs(target[mask]))),
    }


def _prefix_param_grid(param_grid, prefix):
    """Return a copy of ``param_grid`` (dict or list of dicts) with ``prefix`` added to every key."""
    if isinstance(param_grid, dict):
        return {prefix + name: values for name, values in param_grid.items()}
    return [_prefix_param_grid(grid, prefix) for grid in param_grid]


def grid_search(
        X_train,
        y_train,
        X_val,
        y_val,
        X_test,
        y_test,
        model_hyperparameters=None,
        seed=0,
):
    """
    Tune a fixed set of regressors on a train/validation split and score them on test.

    For every model that has an entry in the parameter grids, ``GridSearchCV``
    fits each candidate on (X_train, y_train) and scores it on (X_val, y_val).
    This is a single PredefinedSplit fold, not k-fold CV. The best candidate
    is refit on train+val (GridSearchCV's default ``refit=True``), then
    evaluated on (X_test, y_test) with ``evaluate_metrics``.

    Args:
        X_train, y_train, X_val, y_val, X_test, y_test: 2-D feature arrays and
            their targets. Targets may be 1-D or 2-D (multi-output).
        model_hyperparameters: optional dict of model name -> param grid. When
            truthy it replaces ``PARAM_GRIDS`` entirely. Use plain parameter
            names (e.g. ``'C'``), even for models wrapped in
            ``MultiOutputRegressor`` below.
        seed: ``random_state`` given to the stochastic models.

    Returns:
        List of result dicts, one per evaluated model, with keys 'Model',
        'Best Params', 'Test MAE', 'Test RMSE' and 'Test MAPE'.
    """
    # Combine train and validation data
    X_trainval = np.concatenate([X_train, X_val], axis=0)
    y_trainval = np.concatenate([y_train, y_val], axis=0)

    # test_fold: -1 = always in the fitting part, 0 = the single validation fold.
    test_fold = np.hstack([np.full(X_train.shape[0], -1), np.full(X_val.shape[0], 0)])
    ps = PredefinedSplit(test_fold)

    models = {
        'Linear Regression': LinearRegression(),
        'Random Forest': RandomForestRegressor(random_state=seed),
        'Support Vector Regression': SVR(),
        'K-Nearest Neighbors': KNeighborsRegressor(),
        'Decision Tree': DecisionTreeRegressor(random_state=seed),
        'Gradient Boosting': GradientBoostingRegressor(random_state=seed),
        'XGBoost': xgb.XGBRegressor(random_state=seed),
    }

    # SVR and GradientBoostingRegressor only accept a 1-D target. For 2-D
    # targets, fit one copy per output column via MultiOutputRegressor, whose
    # inner-estimator parameters are addressed as 'estimator__<name>'.
    single_output_only = {'Support Vector Regression', 'Gradient Boosting'}
    wrap_prefix = 'estimator__'
    needs_wrapping = np.ndim(y_trainval) > 1

    # Load model hyperparameters
    param_grids = model_hyperparameters if model_hyperparameters else PARAM_GRIDS

    grid_results = []
    for model_name, model in models.items():
        if model_name not in param_grids:
            print(f'No parameters found for {model_name}.')
            continue

        param_grid = param_grids[model_name]
        wrapped = needs_wrapping and model_name in single_output_only
        if wrapped:
            model = MultiOutputRegressor(model)
            param_grid = _prefix_param_grid(param_grid, wrap_prefix)

        search = GridSearchCV(
            estimator=model,
            param_grid=param_grid,
            cv=ps,
        )
        search.fit(X_trainval, y_trainval)

        best_params = search.best_params_
        if wrapped:
            # Report parameters under their plain names, matching PARAM_GRIDS.
            best_params = {
                name[len(wrap_prefix):]: value for name, value in best_params.items()
            }

        y_test_pred = search.best_estimator_.predict(X_test)

        # Calculate MAE, RMSE, MAPE
        metrics = evaluate_metrics(y_test, y_test_pred)
        mae, rmse, mape = metrics['MAE'], metrics['RMSE'], metrics['MAPE']
        grid_results.append({
            'Model': model_name,
            'Best Params': best_params,
            'Test MAE': mae,
            'Test RMSE': rmse,
            'Test MAPE': mape
        })

        print(f'Best parameters for {model_name}: {best_params}')
        print(f'{model_name} - Test MAE: {mae:.4f}, Test RMSE: {rmse:.4f}, Test MAPE: {mape:.4f}')

    return grid_results

In [4]:
# Dataset configuration forwarded to get_vis_dataset.
n_hist = 1
n_pred = 2
num_visits = 3
# NOTE(review): the printed output reports "Amyloid-Beta PET scans" for this
# filter; presumably (0, 1, 0) selects that modality — confirm in get_dataset.
filter_list = (0, 1, 0)
filter_diagnosis = False
include_pet_volume = False
norm = True
graph_token = False
dataset_dir = ''  # NOTE(review): empty path; point this at the ADNI data directory

dataset_args = dict(
    dataset_dir=dataset_dir,
    n_hist=n_hist,
    n_pred=n_pred,
    num_visits=num_visits,
    filter_list=filter_list,
    filter_diagnosis=filter_diagnosis,
    include_pet_volume=include_pet_volume,
    norm=norm,
    graph_token=graph_token,
)
loader_train, loader_val, loader_test, graph_info = get_vis_dataset(**dataset_args)

# Extend the config with the graph metadata so it can later be wrapped in DictArgs.
dataset_args = {**dataset_args, **graph_info, 'dataset_type': 'brain'}
print(dataset_args)

Getting ADNI data with 
num_visits: 3, 
Amyloid-Beta PET scans
total subjects: 330
train subjects:  231 val subjects:  33 test subjects:  66
Using normalization with mean: [1.12399331], std: [0.23197721]
 > ADNI loaded!
{'train_dataset': GraphTemporalDataset(398), 'valid_dataset': GraphTemporalDataset(55), 'test_dataset': GraphTemporalDataset(120), 'class_init_prob': None}
 > dataset info ends
{'dataset_dir': '', 'n_hist': 1, 'n_pred': 2, 'num_visits': 3, 'filter_list': (0, 1, 0), 'filter_diagnosis': False, 'include_pet_volume': False, 'norm': True, 'graph_token': False, 'node_feature_dim': 1, 'num_nodes': 68, 'num_edges': 1394, 'num_spatial': 19, 'num_in_degree': 42, 'num_out_degree': 42, 'dataset_type': 'brain'}


In [None]:
# Flatten each split into sklearn-ready (X, y) arrays; each call gets its own args namespace.
(x_train, y_train), (x_val, y_val), (x_test, y_test) = [
    prepare_data(split_loader, DictArgs(dataset_args))
    for split_loader in (loader_train, loader_val, loader_test)
]

In [None]:
# Sanity-check the flattened shapes, then peek at the first feature and target values.
print(*(arr.shape for arr in (x_train, y_train, x_val, y_val, x_test, y_test)))
x_train[0, 0], y_train[0, 0]

In [None]:
# Tune every baseline on the train/val split and report its test metrics.
grid_search(
    X_train=x_train, y_train=y_train,
    X_val=x_val, y_val=y_val,
    X_test=x_test, y_test=y_test,
)