In [1]:
import h5py
import networkx as nx
from tqdm import tqdm

In [2]:
import sys
# Make the project root importable so the `dataset` package below resolves.
# NOTE(review): this is a hard-coded absolute path; the import on the last
# line of this cell only works after this append has run.
sys.path.append('/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction')

from dataset.trajectory_dataset_geometric import TrajectoryGeoDataset

In [9]:
def _dataset_name_from_path(file_path):
    """Return the short dataset name contained in ``file_path``.

    The lower-cased path is searched for the known names in a fixed order
    (tdrive, geolife, pneuma, munich); the first match wins.

    Raises:
        ValueError: If none of the known names occurs in the path.
    """
    lowered = file_path.lower()
    for name in ('tdrive', 'geolife', 'pneuma', 'munich'):
        if name in lowered:
            return name
    raise ValueError('Unknown dataset')


def _is_simple_path(edge_idxs, edge_endpoints):
    """Check that the edges of a trajectory form one simple, unbranched path.

    The edges are interpreted as an undirected graph. A trajectory is valid
    when that graph is connected and, if it has more than one edge, contains
    no cycle and no node of degree greater than two.

    Args:
        edge_idxs: Indices into ``edge_endpoints`` (any sized iterable).
        edge_endpoints (list): ``(start, end)`` node pair for every edge index.

    Returns:
        bool: True if the trajectory is valid, False otherwise. A trajectory
        without any edge is reported as invalid.
    """
    # nx.is_connected raises on a graph without nodes, so an empty
    # trajectory has to be rejected before the graph is queried.
    if len(edge_idxs) == 0:
        return False

    # One graph is enough for all three checks; they all look at the same edges.
    subgraph = nx.Graph()
    subgraph.add_edges_from(edge_endpoints[idx] for idx in edge_idxs)

    if not nx.is_connected(subgraph):
        return False
    # A single edge is only subject to the connectivity check.
    if len(edge_idxs) <= 1:
        return True
    if nx.cycle_basis(subgraph):
        return False
    return all(degree <= 2 for _, degree in subgraph.degree())


def train_val_test_split(file_path, train_ratio=0.85, val_ratio=0.05, test_ratio=0.1, save=False, output_folder=None):
    """Create a sequential train, validation, and test split of a trajectory file.

    The paths are split in their stored order (no shuffling): the first
    ``int(n * train_ratio)`` paths are the training split, the following
    ``int(n * val_ratio)`` paths the validation split, and all remaining
    paths the test split.

    With ``save=True`` every split is written to ``<dataset>_train.h5``,
    ``<dataset>_val.h5`` and ``<dataset>_test.h5`` together with the graph
    (node coordinates and edges). Only trajectories whose edges form a
    connected, acyclic, unbranched path are written; the others are counted
    and reported. The trajectory group names keep the index of the path
    inside its split, so skipped paths leave gaps in the numbering.

    Args:
        file_path (str): Path to the source HDF5 file. It must contain one of
            the names tdrive, geolife, pneuma, or munich (case-insensitive).
        train_ratio (float, optional): Defaults to 0.85.
        val_ratio (float, optional): Defaults to 0.05.
        test_ratio (float, optional): Accepted for interface compatibility but
            not used: the test split is always the remainder. Defaults to 0.1.
        save (bool, optional): Write the splits to disk instead of returning
            them. Defaults to False.
        output_folder (str, optional): Folder the split files are written to.
            Defaults to the project data folder.

    Returns:
        tuple | None: ``(train_paths, val_paths, test_paths)`` when ``save`` is
        False, otherwise None.

    Raises:
        ValueError: If the dataset name cannot be derived from ``file_path``.
    """
    import os

    import numpy as np

    # Fail before the (potentially slow) load if the file is not a known dataset.
    dataset = _dataset_name_from_path(file_path)

    print("Load Data...")
    paths, nodes, edges, _edge_coordinates = TrajectoryGeoDataset.load_new_format(file_path, [''], device='cpu')
    edge_endpoints = [(start, end) for start, end in edges]

    n = len(paths)
    print("Dataset: ", dataset)
    print("Total number of paths: ", n)
    train_size = int(n * train_ratio)
    val_size = int(n * val_ratio)

    train_paths = paths[:train_size]
    val_paths = paths[train_size:train_size + val_size]
    test_paths = paths[train_size + val_size:]

    if not save:
        return train_paths, val_paths, test_paths

    # Save the data
    print("Saving the data...")
    if output_folder is None:
        output_folder = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/'

    # The graph is identical for all three files, so convert it only once.
    node_coordinates = [list(pos['pos']) for _, pos in nodes]
    edge_array = np.array(edges)

    invalid_count = 0
    splits = (('train', train_paths), ('val', val_paths), ('test', test_paths))
    for split_name, split_paths in splits:
        output_file_path = os.path.join(output_folder, f'{dataset}_{split_name}.h5')
        with h5py.File(output_file_path, 'w') as f:
            # Save graph structure
            grp_graph = f.create_group('graph')
            grp_graph.create_dataset('node_coordinates', data=node_coordinates)
            grp_graph.create_dataset('edges', data=edge_array)

            # Save the selected trajectories
            grp_trajectories = f.create_group('trajectories')
            for i, path in enumerate(tqdm(split_paths, desc=f'{dataset}_{split_name}')):
                if not _is_simple_path(path['edge_idxs'], edge_endpoints):
                    invalid_count += 1
                    continue
                grp = grp_trajectories.create_group(f'trajectory_{i}')
                for key, value in path.items():
                    grp.create_dataset(key, data=value)
    print("Data saved!")
    print(f"Number of invalid paths: {invalid_count}")

In [None]:
# Write the sequential train/val/test split files for the T-Drive dataset (save=True, so nothing is returned).
train_val_test_split('/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/tdrive.h5', save=True)

In [None]:
def _path_within_region(coordinates, region):
    """Return True if every point of ``coordinates`` lies inside ``region``.

    Args:
        coordinates: 2-D array-like indexed as ``coordinates[:, 0]`` (x) and
            ``coordinates[:, 1]`` (y).
        region (list): ``[[x_min, x_max], [y_min, y_max]]``; bounds are inclusive.
    """
    x_low, x_high = region[0][0], region[0][1]
    y_low, y_high = region[1][0], region[1][1]
    xs = coordinates[:, 0]
    ys = coordinates[:, 1]
    return all((xs >= x_low) & (xs <= x_high) & (ys >= y_low) & (ys <= y_high))


def _plot_split_regions(val_coords, test_coords):
    """Draw the validation and test rectangles on a unit-square background.

    The axes are fixed to [0, 1] x [0, 1].
    NOTE(review): this presumably relies on normalized coordinates - confirm.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()

    # Grey background covering the full unit square
    ax.add_patch(plt.Rectangle((0, 0), 1, 1, edgecolor='black', facecolor='grey'))

    for region, colour, text_colour, label in ((val_coords, 'red', 'white', 'Validation'),
                                               (test_coords, 'blue', 'White', 'Test')):
        x_low, x_high = region[0][0], region[0][1]
        y_low, y_high = region[1][0], region[1][1]
        ax.add_patch(plt.Rectangle((x_low, y_low), x_high - x_low, y_high - y_low,
                                   edgecolor=colour, facecolor=colour))
        ax.text((x_low + x_high) / 2, (y_low + y_high) / 2, label,
                color=text_colour, ha='center', va='center')

    ax.set_aspect('equal')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Validation and test regions')

    plt.show()


def train_val_test_regional_split(file_path, val_coords: list, test_coords: list, save=False, output_folder=None):
    """Split the trajectories of a dataset by geographic region.

    A path belongs to the validation (test) split when all of its
    ``'coordinates'`` lie inside ``val_coords`` (``test_coords``), bounds
    included. Every path that is in neither split is a training path. If the
    two regions overlap, a path lying inside both ends up in both splits; a
    warning is emitted in that case.

    The split ratios are printed and the two regions are plotted. With
    ``save=True`` the splits are written to
    ``<dataset>_coordinate_split_train.h5``,
    ``<dataset>_coordinate_split_x_<..>_y_<..>_val.h5`` and
    ``<dataset>_coordinate_split_x_<..>_y_<..>_test.h5`` together with the
    graph (node coordinates and edges).

    Args:
        file_path (str): Path to the source HDF5 file. It must contain one of
            the names tdrive, geolife, pneuma, or munich (case-insensitive).
        val_coords (list): ``[[x_min, x_max], [y_min, y_max]]`` of the
            validation region.
        test_coords (list): ``[[x_min, x_max], [y_min, y_max]]`` of the test
            region.
        save (bool, optional): Write the splits to disk instead of returning
            them. Defaults to False.
        output_folder (str, optional): Folder the split files are written to.
            Defaults to the project data folder.

    Returns:
        tuple | None: ``(train_paths, val_paths, test_paths, nodes, edges,
        edge_coordinates)`` when ``save`` is False, otherwise None.

    Raises:
        ValueError: If the dataset name cannot be derived from ``file_path``.
    """
    import os
    import warnings

    import numpy as np

    # Fail before the (potentially slow) load if the file is not a known dataset.
    lowered = file_path.lower()
    dataset = next((name for name in ('tdrive', 'geolife', 'pneuma', 'munich') if name in lowered), None)
    if dataset is None:
        raise ValueError('Unknown dataset')

    print("Load Data...")
    # NOTE(review): train_val_test_split calls load_new_format with the extra
    # arguments ([''], device='cpu'); confirm the one-argument call is intended.
    paths, nodes, edges, edge_coordinates = TrajectoryGeoDataset.load_new_format(file_path)

    # Indices of the paths that lie completely inside the respective region
    val_indices = [i for i, path in enumerate(paths) if _path_within_region(path['coordinates'], val_coords)]
    test_indices = [i for i, path in enumerate(paths) if _path_within_region(path['coordinates'], test_coords)]

    # Sets keep the membership test for the training selection O(1) per path.
    val_index_set = set(val_indices)
    test_index_set = set(test_indices)
    shared = val_index_set & test_index_set
    if shared:
        warnings.warn(f"{len(shared)} paths lie in both the validation and the test region "
                      "and are part of both splits.")
    held_out = val_index_set | test_index_set

    val_paths = [paths[i] for i in val_indices]
    test_paths = [paths[i] for i in test_indices]
    train_paths = [paths[i] for i in range(len(paths)) if i not in held_out]

    # Print the ratios (an empty dataset reports 0.0 instead of dividing by zero)
    total = len(paths)
    denominator = total if total else 1
    print("Dataset: ", dataset)
    print("Total number of paths: ", total)
    print("Ratio of training paths: ", len(train_paths) / denominator)
    print("Ratio of validation paths: ", len(val_paths) / denominator)
    print("Ratio of testing paths: ", len(test_paths) / denominator)

    _plot_split_regions(val_coords, test_coords)

    if save:
        # Save the data
        print("Saving the data...")
        if output_folder is None:
            output_folder = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/'

        # The graph is identical for all three files, so convert it only once.
        node_coordinates = [list(pos['pos']) for _, pos in nodes]
        edge_array = np.array(edges)

        file_names = [
            f'{dataset}_coordinate_split_train.h5',
            f'{dataset}_coordinate_split_x_{val_coords[0][0]}_{val_coords[0][1]}_y_{val_coords[1][0]}_{val_coords[1][1]}_val.h5',
            f'{dataset}_coordinate_split_x_{test_coords[0][0]}_{test_coords[0][1]}_y_{test_coords[1][0]}_{test_coords[1][1]}_test.h5',
        ]
        for file_name, split_paths in zip(file_names, [train_paths, val_paths, test_paths]):
            with h5py.File(os.path.join(output_folder, file_name), 'w') as f:
                # Save graph structure
                grp_graph = f.create_group('graph')
                grp_graph.create_dataset('node_coordinates', data=node_coordinates)
                grp_graph.create_dataset('edges', data=edge_array)

                # Save the selected trajectories
                grp_trajectories = f.create_group('trajectories')
                for i, path in enumerate(tqdm(split_paths, desc=file_name)):
                    grp = grp_trajectories.create_group(f'trajectory_{i}')
                    for key, value in path.items():
                        grp.create_dataset(key, data=value)
        print("Data saved!")

    else:
        return train_paths, val_paths, test_paths, nodes, edges, edge_coordinates

In [None]:
# Regional split of pNEUMA: validation region x in [0.1, 0.25], y in [0.4, 0.55]; test region x in [0.6, 1.0], y in [0.0, 0.4].
train_val_test_regional_split('/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/pNEUMA_filtered.h5', [[0.1, 0.25], [0.4, 0.55]], [[0.6, 1.0], [0.0, 0.4]], save=True)

In [None]:
# Regional split of Geolife: validation region x in [0.1, 0.25], y in [0.4, 0.55]; test region x in [0.6, 1.0], y in [0.0, 0.4].
train_val_test_regional_split('/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/geolife.h5', [[0.1, 0.25], [0.4, 0.55]], [[0.6, 1.0], [0.0, 0.4]], save=True)

In [None]:
# Regional split of Munich: validation region x in [0.1, 0.25], y in [0.4, 0.55]; test region x in [0.6, 1.0], y in [0.0, 0.4].
train_val_test_regional_split('/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/munich.h5', [[0.1, 0.25], [0.4, 0.55]], [[0.6, 1.0], [0.0, 0.4]], save=True)