Define a gym-style environment for reinforcement learning (RL)

In [1]:
import torch
import multiprocessing
import numpy as np
import pickle
import time
from common import find_grid_idx, extract_time_feat, min_gps, max_gps, real_distance, block_number, cuda
# Per-coordinate extent of the GPS range (max minus min), as imported from `common`.
delta_gps = max_gps - min_gps

In [2]:
pkl_folder = 'data/pkl/'
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# Use a context manager so the file handle is closed as soon as loading finishes.
with open(pkl_folder + 'grid_order_reward.pkl', 'rb') as rd_file:
    rd_data = pickle.load(rd_file)
# Put 'reward' and 'order' side by side on a new axis 2, so that indexing the
# first two axes yields the pair [reward, order].
rd_data = np.concatenate([np.expand_dims(rd_data['reward'], 2),
                          np.expand_dims(rd_data['order'], 2)], axis=2)

In [3]:
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# Use a context manager so the file handle is closed as soon as loading finishes.
with open(pkl_folder + 'carenv_actions.pkl', 'rb') as action_file:
    action_data = pickle.load(action_file)

In [36]:
class CarEnv:
    """KDDCup2020 car environment.

    State: [lat, lon, hour, average_reward, demand]
    Actions: a list of 7 parallel arrays
             [startlat, startlon, endlat, endlon, ETA, reward, prob];
             ``step`` takes an index into those arrays. The last entry of
             every array is always the "stay where you are" action
             (ETA 0, reward 0, prob 1.0).

    Args:
        actions: (number of blocks, 24) nested list, every item contains actions
                 [startlat, startlon, endlat, endlon, ETA, reward, prob] with shape (x,)
                 except prob with shape (x, 10). One prob column is selected per
                 action from the distance between the car and the action's start
                 point; ``step`` treats it as the probability that the action fails.
        reward_demand: array indexed as [block_index, hour]; each entry is the pair
                 (average reward, average demand) appended to the state.
                 TODO: use NN to predict?
        random_seed: initial seed
        choose_ratio: mean of the normally distributed fraction of candidate
                 actions that is offered
        choose_max: upper bound on the number of offered actions, including the
                 "stay" action
    """

    def __init__(self, actions, reward_demand, random_seed = 0, choose_ratio = 0.6, choose_max = 12):
        self.actions = actions
        self.reward_demand = reward_demand
        self.rng = np.random.RandomState(random_seed)
        self.now_time = None
        self.now_state = None
        self.now_actions = None
        self.choose_ratio = choose_ratio
        self.choose_max = choose_max
        # Time added to the clock when the chosen action fails.
        self.fail_waste = 600.0
        self.is_reset = False

    def _get_reward_demand(self, s):
        """Return the reward_demand entry for the block and hour of state ``s``."""
        row = int(s[0] * block_number[0])
        col = int(s[1] * block_number[1])
        return self.reward_demand[row * block_number[1] + col, s[2]]

    def _select_action(self):
        """Sample the actions offered at the current state.

        Candidates come from the 3x3 block neighbourhood of the current
        position and from the hours ``hour - 1 .. hour + 1`` (wrapping around
        midnight).

        Returns:
            A list of 7 arrays (sampled candidates followed by the "stay"
            action), or the 1-element sentinel ``[default_action]`` when no
            candidate exists at all.
        """
        lat, lon = self.now_state[:2]
        default_action = [np.array([lat]), np.array([lon]),
                          np.array([lat]), np.array([lon]),
                          np.array([0]), np.array([0]), np.array([1.0])]
        pos = (block_number * self.now_state[:2]).astype(int)
        expand = 1
        t_expand = 1
        alla = []
        for i in range(-expand, expand + 1):
            for j in range(-expand, expand + 1):
                nowp = [i, j] + pos
                # Skip neighbours that fall outside the grid.
                if (nowp < 0).any() or (nowp >= block_number).any():
                    continue
                block_idx = (nowp * [block_number[1], 1]).sum()
                for dt in range(-t_expand, t_expand + 1):
                    hour = (dt + self.now_state[2] + 24) % 24
                    candidates = self.actions[block_idx][hour]
                    if len(candidates[0]) > 0:
                        alla.append(candidates)
        # Only non-empty groups are collected, so "no candidates" shows up as an
        # empty list here. zip(*[]) would then yield nothing and indexing
        # alla[0] would raise IndexError, so test the list itself.
        if not alla:
            return [default_action]
        alla = [np.concatenate(x) for x in zip(*alla)]
        total = len(alla[0])
        choose_num = int(self.rng.normal(self.choose_ratio, 2) * total)
        if choose_num <= 0:
            choose_num = 1
        if choose_num > total:
            choose_num = total
        if choose_num >= self.choose_max:
            # Leave one slot for the "stay" action appended below.
            choose_num = self.choose_max - 1
        choose = self.rng.choice(total, choose_num, replace = False)
        # Fancy indexing copies, so the in-place ETA update below never touches
        # the shared action tables.
        alla = [x[choose] for x in alla]
        # Straight-line distance from the car to each action's start point.
        gps = np.stack(alla[:2]).transpose(1, 0)
        gps = (gps - self.now_state[:2]) * real_distance
        gps = (gps ** 2).sum(axis=1) ** 0.5
        alla[4] += (gps / 5.7).astype(int)  # add time driving to there
        # Distance bucket (width 200, capped at 9) selects one of the 10 prob columns.
        gps = (gps / 200).astype(int)
        gps[gps > 9] = 9
        alla[-1] = np.choose(gps, alla[-1].T)
        return [np.append(column, extra) for column, extra in zip(alla, default_action)]

    def _advance(self, lat, lon, length):
        """Move the clock by ``length``, relocate to (lat, lon) and resample actions."""
        self.now_time = self.now_time + length
        # NOTE(review): time.localtime depends on the machine's time zone, so the
        # hour derived from now_time is not portable across hosts.
        s = [lat, lon, time.localtime(self.now_time).tm_hour]
        s += self._get_reward_demand(s).tolist()
        self.now_state = s
        acts = self._select_action()
        if len(acts) == 1:
            # No candidate nearby: unwrap the sentinel so now_actions keeps the
            # usual 7-array layout, offering only the "stay" action.
            acts = acts[0]
        self.now_actions = acts
        return s

    def reset(self):
        """Draw random start states until one has at least one real action.

        Returns:
            (state, info) where info is ``{'time': now_time}``.
        """
        while True:
            self.now_time = self.rng.randint(24) * 3600
            s = [*self.rng.random(2), time.localtime(self.now_time).tm_hour]
            s += self._get_reward_demand(s).tolist()
            self.now_state = s
            a = self._select_action()
            if len(a) == 1:
                # Sentinel: nothing to do here, draw another start state.
                continue
            self.now_actions = a
            break
        self.is_reset = True
        return self.now_state, {'time': self.now_time}

    def step(self, action):
        """Execute the action at index ``action`` of the current action arrays.

        With probability ``prob`` the action fails: reward is 0, the car stays
        in place and ``fail_waste`` is added to the clock. Otherwise the car
        moves to the action's end point, ETA is added to the clock and the
        action's reward is returned.

        Returns:
            (state, reward, length, info) where info is ``{'time': now_time}``.
        """
        assert(self.is_reset)
        # a = [endlat, endlon, ETA, reward, prob]
        a = [x[action] for x in self.now_actions][2:]
        if self.rng.random() < a[4]:
            reward = 0
            length = self.fail_waste  # waste some time
            lat, lon = self.now_state[:2]
        else:
            reward = a[3]
            length = a[2]
            lat, lon = a[:2]
        s = self._advance(lat, lon, length)
        return s, reward, length, {'time': self.now_time}

    def get_actions(self):
        """Return the action arrays currently on offer (requires a prior reset)."""
        assert(self.is_reset)
        return self.now_actions

In [37]:
class EnvWorker(multiprocessing.Process):
    """Daemon process that owns one environment and serves commands over a pipe.

    Messages are ``(cmd, data)`` tuples. ``'step'``, ``'reset'`` and
    ``'get_actions'`` are forwarded to the environment and the result is sent
    back; ``'close'`` shuts the worker down.

    Args:
        env: environment class (or factory), called as ``env(*envargs)``
        envargs: positional arguments for ``env``
        pipe1: connection end this worker reads commands from and replies on
        pipe2: the opposite connection end; closed at the start of ``run``
    """

    def __init__(self, env, envargs, pipe1, pipe2):
        super().__init__(daemon=True)
        self.env = env(*envargs)
        self.pipe = pipe1
        self.pipe2 = pipe2

    def run(self):
        """Serve commands until 'close' is received or the pipe reaches EOF."""
        # This end belongs to the other side of the pipe; the worker never uses it.
        self.pipe2.close()
        while True:
            try:
                command, payload = self.pipe.recv()
                if command == 'close':
                    self.pipe.close()
                    break
                if command == 'step':
                    reply = self.env.step(payload)
                elif command == 'get_actions':
                    reply = self.env.get_actions()
                elif command == 'reset':
                    reply = self.env.reset()
                else:
                    raise NotImplementedError
                self.pipe.send(reply)
            except EOFError:
                break
                
class EnvVecs:
    """Run several environments in parallel, one EnvWorker process each.

    Args:
        env_class: environment class handed to every worker
        n_envs: number of worker processes
        env_args: positional arguments for ``env_class``
        arg_seed_pos: index in ``env_args`` that is overwritten with a
            per-worker seed (``seed``, ``seed + 1``, ...); -1 leaves the
            arguments untouched
        seed: seed given to the first worker
    """

    def __init__(self, env_class, n_envs, env_args, arg_seed_pos = -1, seed = 0):
        self.waiting = False
        self.closed = False

        pipe_pairs = [multiprocessing.Pipe(duplex=True) for _ in range(n_envs)]
        self.remotes, self.work_remotes = zip(*pipe_pairs)
        self.processes = []
        for parent_end, child_end in zip(self.remotes, self.work_remotes):
            worker_args = list(env_args)
            if arg_seed_pos != -1:
                worker_args[arg_seed_pos] = seed
                seed += 1
            # EnvWorker is a daemon process, so a crashing main process does not hang on it.
            worker = EnvWorker(env_class, worker_args, child_end, parent_end)  # pytype:disable=attribute-error
            worker.start()
            self.processes.append(worker)
            # The worker holds its own handle; drop the parent's copy of the child end.
            child_end.close()
        self.is_reset = False

    def _send_all(self, message):
        """Send the same message to every worker."""
        for remote in self.remotes:
            remote.send(message)

    def _recv_all(self):
        """Collect one reply from every worker, in worker order."""
        return [remote.recv() for remote in self.remotes]

    def step_async(self, actions):
        """Send one action index to each worker without waiting for results."""
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """Collect the step results and refresh the cached action lists.

        Returns:
            (obs, rewards, lengths, infos) batched over the workers.
        """
        results = self._recv_all()
        self.waiting = False
        obs, rews, lengths, infos = zip(*results)
        self._get_actions()
        return self._flatten_obs(obs), np.stack(rews), np.stack(lengths), self._flatten_info(infos)

    def step(self, actions):
        """Step every environment with its action and return the batched results."""
        assert(self.is_reset)
        self.step_async(actions)
        return self.step_wait()

    def close(self):
        """Ask every worker to exit and join them; does nothing when already closed."""
        if self.closed:
            return
        if self.waiting:
            # Drain the pending step replies before sending 'close'.
            self._recv_all()
        self._send_all(('close', None))
        for process in self.processes:
            process.join()
        self.closed = True

    def _get_actions(self):
        """Fetch the current action lists from all workers into ``self.actions``."""
        self._send_all(('get_actions', None))
        self.actions = self._recv_all()

    def get_actions(self):
        """Return the action lists cached by the last reset or step."""
        assert(self.is_reset)
        return self.actions

    def reset(self):
        """Reset every environment.

        Returns:
            (obs, infos) batched over the workers.
        """
        self._send_all(('reset', None))
        obs, infos = zip(*self._recv_all())
        self.is_reset = True
        self._get_actions()
        return self._flatten_obs(obs), self._flatten_info(infos)

    def _flatten_obs(self, obs):
        """Turn per-env observations into a list with one stacked array per component."""
        return [np.stack(component) for component in zip(*obs)]

    def _flatten_info(self, info):
        """Merge per-env info dicts into one dict of stacked arrays (keys from the first)."""
        if len(info) == 0:
            return {}
        return {key: np.stack([entry[key] for entry in info]) for key in info[0].keys()}

def get_carenvvec(number):
    """Build an EnvVecs of ``number`` CarEnv workers on the loaded action/reward data.

    Index 2 of the argument tuple (CarEnv's ``random_seed``) is the seed slot,
    so the workers are seeded 0, 1, 2, ...
    """
    car_env_args = (action_data, rd_data, 0)
    return EnvVecs(CarEnv, number, car_env_args, arg_seed_pos=2)

Tests

In [38]:
# Single-process environment with the default seed and sampling parameters.
ce = CarEnv(action_data, rd_data)

In [39]:
# Smoke test: reset, show the offered actions, then take the action at index 0.
print(ce.reset())
print(ce.get_actions())
print(ce.step(0))

([0.5928446182250183, 0.8442657485810173, 20, 8.965, 2.0], {'time': 43200})
[array([0.5738    , 0.57934   , 0.59217   , 0.588904  , 0.612006  ,
       0.59296   , 0.575324  , 0.586206  , 0.56336   , 0.58534   ,
       0.5738    , 0.59284462]), array([0.83273333, 0.86437778, 0.84631778, 0.86904444, 0.83994222,
       0.84691111, 0.83899778, 0.86312444, 0.82107333, 0.86264444,
       0.83273333, 0.84426575]), array([0.42086   , 0.424858  , 0.42918   , 0.7751    , 0.49718   ,
       0.68324   , 0.55118   , 0.4824    , 0.57424   , 0.52422   ,
       0.52632   , 0.59284462]), array([0.57      , 0.54268444, 0.50726667, 0.44433333, 0.56582222,
       0.39708889, 0.55628889, 0.30939778, 0.43475556, 0.56164444,
       0.56731111, 0.84426575]), array([3097, 2168, 3929, 3337, 2331, 2921, 1906, 3474, 3165, 5242, 1955,
          0]), array([ 9.72, 10.05, 12.54, 15.63,  8.38, 14.06,  6.56, 18.56, 10.78,
       11.63,  6.28,  0.  ]), array([0.052844, 0.050365, 0.168158, 0.130051, 0.055499, 0.249974,


In [40]:
# Actions offered after the previous step, followed by another step with index 0.
print(ce.get_actions())
print(ce.step(0))

[array([0.44084 , 0.43052 , 0.43584 , 0.42646 , 0.44098 , 0.45422 ,
       0.44985 , 0.42266 , 0.451182, 0.446218, 0.41524 , 0.42086 ]), array([0.54700222, 0.58884444, 0.5672    , 0.56504444, 0.54806667,
       0.54604889, 0.56584444, 0.56675111, 0.54472222, 0.54980222,
       0.54482222, 0.57      ]), array([0.5516  , 0.5164  , 0.51378 , 0.48084 , 0.49344 , 0.48656 ,
       0.52864 , 0.498402, 0.4864  , 0.50174 , 0.538624, 0.42086 ]), array([0.46753333, 0.49766667, 0.561     , 0.77008889, 0.65846667,
       0.56446667, 0.57155556, 0.57826667, 0.5712    , 0.59786667,
       0.59388444, 0.57      ]), array([1839, 1787,  760, 1198, 1083,  942, 1576,  897,  913, 1229, 1207,
          0]), array([4.01, 3.76, 2.36, 4.29, 3.4 , 1.82, 2.91, 2.45, 1.83, 2.3 , 3.49,
       0.  ]), array([0.069046, 0.038289, 0.015818, 0.019119, 0.0474  , 0.365834,
       0.03859 , 0.011601, 0.072534, 0.051171, 0.044322, 1.      ])]
([0.5516000000000076, 0.46753333333333474, 21, 3.3674456468752973, 10257.0], 4.01

In [41]:
# Vectorised environment backed by 4 worker processes.
ev = get_carenvvec(4)

In [42]:
# Reset all workers; prints the batched observations and the info dict.
print(ev.reset())

([array([0.59284462, 0.99718481, 0.18508208, 0.07072488]), array([0.84426575, 0.93255736, 0.93154087, 0.83994904]), array([20, 13, 16, 18]), array([ 8.965, 19.31 ,  0.   ,  0.   ]), array([2., 1., 0., 0.])], {'time': array([43200, 18000, 28800, 36000])})


In [64]:
# Show each worker's offered actions, then step every worker with action index 0.
print(ev.get_actions())
print(ev.step([0] * len(ev.processes)))

[[array([0.71686, 0.68848]), array([0.47857778, 0.49515556]), array([0.46846, 0.68848]), array([0.52877778, 0.49515556]), array([1601,    0]), array([9.4, 0. ]), array([0.037821, 1.      ])], [array([0.79738, 0.81958]), array([0.27688889, 0.2668    ]), array([0.47304, 0.81958]), array([0.54555778, 0.2668    ]), array([2435,    0]), array([14.73,  0.  ]), array([0.139955, 1.      ])], [array([0.606174, 0.633384, 0.60598 , 0.6273  , 0.636892, 0.6059  ,
       0.608738, 0.63446 , 0.61336 , 0.61886 , 0.630118, 0.63506 ]), array([0.48941778, 0.45654889, 0.48164444, 0.4726    , 0.46532444,
       0.49017778, 0.48791111, 0.45328889, 0.48666667, 0.4416    ,
       0.46366444, 0.46155556]), array([0.52158 , 0.85728 , 0.673258, 0.67582 , 0.55558 , 0.75244 ,
       0.59726 , 0.67842 , 0.57428 , 0.57066 , 0.48446 , 0.63506 ]), array([0.4604    , 0.45708889, 0.34146667, 0.5196    , 0.42435556,
       0.52944444, 0.55902222, 0.516     , 0.50526667, 0.46153333,
       0.48115556, 0.46155556]), array(