This file generates the data used in the Flatland+ experiments. 

In [None]:
# Imports, parameters, etc. 
from flatland_plus_environment import FlatLandPlus
from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt
# Nominal episode length. The generation cell uses it only to derive the
# number of rollouts: int(TOTAL_SAMPLES / MAX_EP_LENGTH).
MAX_EP_LENGTH = 200
# Target number of samples in each generated dataset.
TOTAL_SAMPLES = 100_000 # OGBench had 1 million but this seems more reasonable for this
# Fraction of the start-goal pairs used to size the held-out (testing) split.
HOLDOUT_SPLIT = 0.3333
# Seed for the generator that shuffles the start-goal pairs before splitting.
RANDOM_SEED = 42
# Directory the 'dataset_n<N>.npy' files are written into.
save_directory = 'flatland_data'
# Values of N to generate data for; each N is passed to FlatLandPlus as
# n_dims and yields 2 * N facets (start/goal indices).
N_list = [2]
# NOTE(review): SIZE is not referenced by any cell visible here - confirm it
# is still needed.
SIZE = 5

In [None]:
from flatland_plus_environment import FlatLandPlus
from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt
import os

# Generates one dataset per N in N_list: noisy "start -> middle -> goal"
# rollouts over the training start-goal pairs, saved to
# <save_directory>/dataset_n<N>.npy.
for N in N_list:
    # ---- Start-goal pairs and train/test split ----
    # Every ordered pair of distinct facet indices, in start-major order.
    n_facets = N * 2
    sg_pairs = [
        (start_index, goal_index)
        for start_index in range(n_facets)
        for goal_index in range(n_facets)
        if start_index != goal_index
    ]
    rng = np.random.default_rng(seed=RANDOM_SEED)
    rng.shuffle(sg_pairs)
    # NOTE(review): the cut index is `1 - holdout_count`, not `-holdout_count`,
    # so the testing split holds one pair fewer than
    # int(HOLDOUT_SPLIT * len(sg_pairs)) (N=2: 12 pairs -> 10 train / 2 test).
    # Left unchanged on purpose: altering it would change which pairs are held
    # out relative to already-generated datasets and to the split-printing
    # cell, which uses the same expression.
    split_at = 1 - int(HOLDOUT_SPLIT * len(sg_pairs))
    training_split = sg_pairs[:split_at]
    testing_split = sg_pairs[split_at:]
    assert len(sg_pairs) == len(training_split) + len(testing_split)
    if not training_split:
        # Without this the rollout loop below fails with a bare IndexError.
        raise ValueError(
            f"No training start-goal pairs for N={N}; check HOLDOUT_SPLIT"
        )

    # ---- Rollout collection ----
    env = FlatLandPlus(n_dims=N)
    # Fresh generator for the action noise, so the noise stream does not
    # depend on how many draws the shuffle above consumed.
    rng = np.random.default_rng(seed=RANDOM_SEED)
    gaussian_scale = 0.05
    obs_record = []
    actions = []
    terminals = []
    # NOTE(review): MAX_EP_LENGTH is never handed to the environment; episodes
    # end only when env.step reports truncation. The rollout count therefore
    # assumes the environment truncates after MAX_EP_LENGTH steps - confirm.
    n_rollouts = int(TOTAL_SAMPLES / MAX_EP_LENGTH)
    for rollout in range(n_rollouts):
        # Cycle through the training pairs in order, wrapping around.
        start_index, goal_index = training_split[rollout % len(training_split)]
        obs, info = env.reset(start_idx=start_index, goal_idx=goal_index)
        goal = info['goal']
        # The "middle" waypoint is the origin of the observation space.
        middle = np.zeros_like(obs)
        been_to_middle = False
        truncated = False
        while not truncated:
            obs_record.append(obs)
            # Head for the middle first; once it has been reached, head for
            # the goal. The action is the displacement to the current target
            # plus small Gaussian noise.
            target = goal if been_to_middle else middle
            action = target - obs
            action = action + rng.normal(loc=0, scale=gaussian_scale, size=action.shape)
            obs, _, _, truncated, _ = env.step(action)
            if np.linalg.norm(obs) < env.tolerance:
                been_to_middle = True
            actions.append(action)
            terminals.append(truncated)

    # ---- Save ----
    obs_record = np.array(obs_record)
    actions = np.array(actions)
    terminals = np.array(terminals)
    print(obs_record.shape)
    save_dict = {'observations': obs_record, 'terminals': terminals, 'actions': actions}
    # np.save raises if the target directory does not exist yet.
    os.makedirs(save_directory, exist_ok=True)
    # np.save pickles the dict into a 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(save_directory + '/dataset_n' + str(N) + '.npy', save_dict)

(100000, 2)


In [3]:
# Reporting cell: rebuilds the train/test start-goal split for every N in
# N_list with the same seed and slicing as the generation cell, then prints
# it so that held-out (testing) pairs can be copied out for a given dataset.
for N in N_list:
    n_facets = N * 2
    # Enumerate every ordered (start, goal) pair with distinct indices.
    sg_pairs = []
    for start_index in range(n_facets):
        for goal_index in range(n_facets):
            if start_index == goal_index:
                continue
            sg_pairs.append((start_index, goal_index))
    # Deterministic shuffle, then cut the list into training / testing.
    rng = np.random.default_rng(seed=RANDOM_SEED)
    rng.shuffle(sg_pairs)
    # The cut sits at 1 - holdout_count (N=2: 12 pairs -> 10 train / 2 test).
    holdout_count = int(HOLDOUT_SPLIT * len(sg_pairs))
    split_at = 1 - holdout_count
    training_split, testing_split = sg_pairs[:split_at], sg_pairs[split_at:]
    assert len(training_split) + len(testing_split) == len(sg_pairs)
    print('train', training_split)
    print('test', testing_split)

train [(0, 1), (2, 1), (2, 0), (3, 0), (3, 2), (1, 0), (1, 3), (0, 3), (1, 2), (3, 1)]
test [(0, 2), (2, 3)]
