任务数据处理相关脚本

In [11]:
import os
import random
from collections import OrderedDict
from datetime import datetime, timedelta

import neorl2
import numpy as np
import pandas as pd
import yaml

模拟生成csv数据

In [6]:
# Export each NeoRL2 "RocketRecovery" episode to its own CSV file (./csv/<n>.csv, n starting at 1),
# adding a synthetic "time" column with one row per second from a random midnight start.
env = neorl2.make("RocketRecovery")
data, _ = env.get_dataset()

obs = data["obs"]
next_obs = data["next_obs"]
action = data["action"]

# A done/truncated transition ends an episode, so the next episode starts one step later.
# Prepending 0 turns the boundaries into [start, end) pairs via zip(index[:-1], index[1:]).
# Transitions after the last done/truncated flag (if any) are not exported.
index = np.where(np.logical_or(data["done"], data["truncated"]))[0] + 1
index = [0] + list(index)

# Loop-invariant settings, built once instead of once per episode.
# The DataFrame below requires state + engine_power to have 9 columns in total; presumably obs
# carries the first 7 (including wind_speed) and action the last 2 — TODO confirm.
columns = [
    "X-axis",
    "Y-axis",
    "X-velocity",
    "Y-velocity",
    "angle",
    "angular_velocity",
    "wind_speed",
    "main_engine_power",
    "left_right_engine_power",
]
start_date = datetime(2010, 3, 1)
end_date = datetime(2022, 3, 31)
candidate_days = pd.date_range(start=start_date, end=end_date)

output_dir = "./csv"
os.makedirs(output_dir, exist_ok=True)  # to_csv fails if the directory does not exist

for episode_id, (start_index, end_index) in enumerate(zip(index[:-1], index[1:]), start=1):
    # Append the episode's final next_obs so the terminal state is included; the last action
    # is repeated so the action rows line up with the extra state row.
    state = np.concatenate(
        [obs[start_index:end_index], next_obs[end_index - 1:end_index]], axis=0
    )
    engine_power = np.concatenate(
        [action[start_index:end_index], action[end_index - 1:end_index]], axis=0
    )
    df = pd.DataFrame(np.concatenate([state, engine_power], axis=1), columns=columns)

    # Random start day at midnight, then one timestamp per row at 1-second spacing,
    # placed as the first column.
    start_time = datetime.combine(random.choice(candidate_days).date(), datetime.min.time())
    df.insert(0, "time", [start_time + timedelta(seconds=i) for i in range(len(df))])

    # NOTE(review): pandas writes the row index as an unnamed leading column by default
    # (index=True); kept as before — confirm the downstream loader expects it.
    df.to_csv(os.path.join(output_dir, f"{episode_id}.csv"))

生成决策流图

In [15]:
# Build the decision-flow graph plus per-column metadata and write it to a YAML config file.

# Each key is an output node; its value lists the input nodes it is computed from.
graph = {
    "engine_power": ["rocket_state", "wind_speed"],
    "next_rocket_state": ["rocket_state", "wind_speed", "engine_power"],
}

# CSV column names grouped by the node ("dim") they belong to. These strings must match the
# column headers of the exported CSV files.
node_columns = {
    "rocket_state": ["X-axis", "Y-axis", "X-velocity", "Y-velocity", "angle", "angular_velocity"],
    "wind_speed": ["wind_speed"],
    "engine_power": ["main_engine_power", "left_right_engine_power"],
}

nodes = ["rocket_state", "wind_speed", "engine_power"]

# One single-key mapping per column, in node order, then column order within each node.
# Fresh dicts are created for every column, so yaml.dump emits no anchors/aliases.
column_entries = []
for node in nodes:
    if node not in node_columns:
        raise NotImplementedError
    column_entries.extend(
        {column: {"type": "continuous", "dim": node}} for column in node_columns[node]
    )

config_dict = {"metadata": {"graph": graph, "columns": column_entries}}

config_path = "../revive_train/data/rocketrecovery.yaml"
# open(..., "w") fails if the target directory is missing.
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w", encoding="utf-8") as f:
    # yaml.dump sorts mapping keys by default, so key insertion order does not affect the output.
    yaml.dump(config_dict, f)