Define a Gym-style environment for RL

In [1]:
import torch
import multiprocessing
import numpy as np
import pickle
import time
from models import GridModel
from common import find_grid_idx, extract_time_feat, min_gps, max_gps, real_distance, block_number, cuda
delta_gps = max_gps - min_gps

In [2]:
# Grid model used by CarEnv._model_grid to predict the last state feature.
model = GridModel()
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
# the model is only ever fed CPU tensors in this notebook.
model.load_state_dict(torch.load('data/model/grid/best.pt', map_location='cpu')['model'])
# Inference only: switch off training-time behavior.
model.eval()

GridModel(
  (houremb): Embedding(24, 8)
  (order): Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
  (reward): Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [3]:
pkl_folder = 'data/pkl/'
# NOTE(review): pickle can execute arbitrary code on load; only use trusted files.
with open(pkl_folder + 'grid_order_reward.pkl', 'rb') as f:
    rd_data = pickle.load(f)
# [offset, scale] applied to each channel below (presumably precomputed mean and
# std; for 'order' they are applied after the log transform).
meanstd = {'order': [1.3818544802263453, 2.0466071372530115], 'reward': [0.003739948797879627, 0.000964668315987685]}
for key, (mean, std) in meanstd.items():
    if key == 'order':
        # Compress the order counts with log(1 + x) before standardizing.
        rd_data[key] = np.log(rd_data[key] + 1)
    # Standardize in place so the dict entry itself is updated.
    values = rd_data[key]
    values -= mean
    values /= std
# Stack the two channels on a new last axis: [..., 0] = reward, [..., 1] = order.
rd_data = np.concatenate([np.expand_dims(rd_data['reward'],2), np.expand_dims(rd_data['order'],2)], axis=2)

In [4]:
# Candidate-action table consumed by CarEnv (indexed there as [block][hour]).
# NOTE(review): pickle can execute arbitrary code on load; only use trusted files.
with open(pkl_folder + 'carenv_actions.pkl', 'rb') as f:
    action_data = pickle.load(f)

In [5]:
class CarEnv:
    """KDDCup2020 car environment.

    State: [lat, lon, hour, average_reward, demand]
    Action: [lat, lon, ETA, reward, prob]

    Positions are fractions of the map: anything outside [0, 1) is treated as
    off-grid by ``_get_reward_demand``. Time is kept in seconds of day
    (``now_time`` is always reduced modulo 86400 and ``hour = now_time // 3600``).

    Args:
        actions: (block_number, 24) list, every item contains actions
                 [startlat, startlon, endlat, endlon, ETA, reward, prob] with shape (x,) except prob
                 with shape (x,10)
        reward_demand: array indexed as [block_index, hour] that yields
                 [average reward, average demand]    TODO: use NN to predict?
        random_seed: seed of the private RandomState; drawn from numpy's global
                 generator when None
        choose_ratio: mean fraction of the candidate actions that is offered
        choose_max: maximum number of offered actions, counting the always-present
                 "stay" action
    """

    # reset() re-samples a start state until it offers a real action; this bound
    # keeps it from spinning forever when the action table is (almost) empty.
    _MAX_RESET_TRIES = 1000

    def __init__(self, actions, reward_demand, random_seed = None, choose_ratio = 0.6, choose_max = 12):
        self.actions = actions
        self.reward_demand = reward_demand
        if random_seed is None:
            random_seed = np.random.randint(1 << 31)
        self.rng = np.random.RandomState(random_seed)
        self.now_time = None
        self.now_state = None
        self.now_actions = None
        self.choose_ratio = choose_ratio
        self.choose_max = choose_max
        # Seconds consumed when the chosen action does not go through.
        self.fail_waste = 600.0
        self.is_reset = False

    def _get_reward_demand(self, s):
        """Look up [average_reward, demand] for state prefix ``s = [lat, lon, hour]``.

        Returns ``np.array([0, 0])`` when the position is off the grid.
        """
        if s[0] < 0 or s[1] < 0 or s[0] >= 1 or s[1] >= 1:
            return np.array([0, 0])
        block_idx = int(s[0] * block_number[0]) * block_number[1] + int(s[1] * block_number[1])
        return self.reward_demand[block_idx, s[2]]

    def _select_action(self):
        """Sample the actions offered at ``now_state``.

        Candidates are gathered from the 3x3 block neighborhood and the hours
        hour-1 .. hour+1 (wrapping around midnight), then a random subset is kept.

        Returns:
            A list of 7 arrays [startlat, startlon, endlat, endlon, ETA, reward,
            prob]; the last entry of every array is the "stay in place" action
            (ETA 0, reward 0, prob 1.0). When no candidate exists, only that
            default action is returned.
        """
        default_action = [np.array([self.now_state[0]]), np.array([self.now_state[1]]),
                          np.array([self.now_state[0]]), np.array([self.now_state[1]]),
                          np.array([0]), np.array([0]), np.array([1.0])]
        pos = (block_number * self.now_state[:2]).astype(int)
        expand = 1
        t_expand = 1
        candidates = []
        for di in range(-expand, expand + 1):
            for dj in range(-expand, expand + 1):
                nowp = [di, dj] + pos
                if (nowp < 0).any() or (nowp >= block_number).any():
                    continue  # neighbor block lies outside the grid
                block_idx = (nowp * [block_number[1], 1]).sum()
                for dk in range(-t_expand, t_expand + 1):
                    hour = (dk + self.now_state[2] + 24) % 24
                    block_actions = self.actions[block_idx][hour]
                    if len(block_actions[0]) > 0:
                        candidates.append(block_actions)
        if not candidates:
            return default_action
        # Merge the per-block tables column by column.
        alla = [np.concatenate(x) for x in zip(*candidates)]
        total = len(alla[0])
        # Number of offered actions: noisy fraction of the candidates, clamped to
        # [1, total] and leaving one slot for the default action.
        choose_num = int(self.rng.normal(self.choose_ratio, 2) * total)
        if choose_num <= 0:
            choose_num = 1
        if choose_num > total:
            choose_num = total
        if choose_num >= self.choose_max:
            choose_num = self.choose_max - 1
        choose = self.rng.choice(total, choose_num, replace = False)
        alla = [x[choose] for x in alla]
        # Straight-line distance from the current position to each pickup point.
        gps = np.stack(alla[:2]).transpose(1, 0)
        gps = (gps - self.now_state[:2]) * real_distance
        gps = (gps ** 2).sum(axis=1) ** 0.5
        alla[4] += (gps / 5.7).astype(int) # add time driving to there
        # Distance bucket (capped at 9) selects one of the 10 prob columns per action.
        gps = (gps / 200).astype(int)
        gps[gps > 9] = 9
        alla[-1] = np.choose(gps, alla[-1].T)
        return [np.append(column, default) for column, default in zip(alla, default_action)]

    def _model_grid(self, args):
        """Return the second output of the global ``model`` for ``args = [lat, lon, hour]``."""
        with torch.no_grad():
            res = model(torch.tensor([args[:2]]), torch.tensor([args[2]]))
            return res[1].item()

    def _move_to(self, lat, lon, elapsed):
        """Advance the clock by ``elapsed`` seconds, rebuild the state at (lat, lon)
        and resample the offered actions. Returns the new state list."""
        self.now_time = int(self.now_time + elapsed) % 86400
        s = [lat, lon, self.now_time // 3600]
        s += self._get_reward_demand(s).tolist()
        # The looked-up last feature is replaced by the model prediction.
        s[-1] = self._model_grid(s[:3])
        self.now_state = s
        self.now_actions = self._select_action()
        return s

    def reset(self):
        """Start a new episode at a random position and time of day.

        Start states that offer nothing but the default "stay" action are
        re-sampled (up to ``_MAX_RESET_TRIES`` times; the last sample is kept if
        none qualifies).

        Returns:
            (state, info) where ``info = {'time': seconds_of_day}``.
        """
        for _ in range(self._MAX_RESET_TRIES):
            self.now_time = self.rng.randint(86400)
            s = [*self.rng.random(2), self.now_time // 3600]
            s += self._get_reward_demand(s).tolist()
            s[-1] = self._model_grid(s[:3])
            self.now_state = s
            self.now_actions = self._select_action()
            # _select_action always returns 7 columns, so the number of offered
            # actions is the length of a column, not of the list itself.
            if len(self.now_actions[0]) > 1:
                break
        self.is_reset = True
        return self.now_state, {'time': self.now_time}

    def step(self, action):
        """Take the offered action with index ``action``.

        With probability ``prob`` of that action the attempt fails: the car stays
        where it is, earns 0 and loses ``fail_waste`` seconds. Otherwise it moves
        to the action's end point, earning its reward after its ETA.

        Returns:
            (state, reward, length, info) where ``length`` is the elapsed time in
            seconds and ``info = {'time': seconds_of_day}``.
        """
        assert(self.is_reset)
        end_lat, end_lon, eta, reward, prob = [x[action] for x in self.now_actions][2:]
        if self.rng.random() < prob:
            reward = 0
            length = self.fail_waste # waste some time
            lat, lon = self.now_state[:2]
        else:
            length = eta
            lat, lon = end_lat, end_lon
        s = self._move_to(lat, lon, length)
        return s, reward, length, {'time': self.now_time}

    def get_actions(self):
        """Return the actions currently offered (see ``_select_action``)."""
        assert(self.is_reset)
        return self.now_actions

In [6]:
class EnvWorker(multiprocessing.Process):
    """Daemon process that owns one environment and serves commands over a pipe.

    Args:
        env: callable building the environment
        envargs: positional arguments passed to ``env``
        pipe1: connection this worker reads commands from and writes replies to
        pipe2: the opposite end of the pipe; closed as soon as the worker runs
    """

    def __init__(self, env, envargs, pipe1, pipe2):
        super().__init__(daemon=True)
        self.env = env(*envargs)
        self.pipe = pipe1
        self.pipe2 = pipe2

    def run(self):
        """Answer ``(cmd, data)`` messages until 'close' arrives or the pipe hits EOF.

        'step', 'get_actions' and 'reset' send back the result of the matching
        environment call; any other command raises NotImplementedError.
        """
        self.pipe2.close()
        env, pipe = self.env, self.pipe
        try:
            while True:
                cmd, data = pipe.recv()
                if cmd == 'close':
                    pipe.close()
                    return
                if cmd == 'step':
                    reply = env.step(data)
                elif cmd == 'get_actions':
                    reply = env.get_actions()
                elif cmd == 'reset':
                    reply = env.reset()
                else:
                    raise NotImplementedError
                pipe.send(reply)
        except EOFError:
            # The other end went away: stop serving.
            return
                
class EnvVecs:
    """Run several environments in worker processes and batch their results.

    Args:
        env_class: callable building one environment
        n_envs: number of worker processes
        env_args: positional arguments for ``env_class``
        arg_seed_pos: index in ``env_args`` that is overwritten with a per-worker
            seed; -1 leaves the arguments untouched
        seed: seed of the first worker; every following worker gets the next integer
    """

    def __init__(self, env_class, n_envs, env_args, arg_seed_pos = -1, seed = 0):
        self.waiting = False
        self.closed = False
        self.is_reset = False
        self.actions = None
        self.processes = []

        self.remotes, self.work_remotes = zip(*[multiprocessing.Pipe(duplex=True) for _ in range(n_envs)])
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            args = list(env_args)
            if arg_seed_pos != -1:
                args[arg_seed_pos] = seed
                seed += 1
            # EnvWorker is a daemon process: if the main process crashes, we
            # should not cause things to hang.
            process = EnvWorker(env_class, args, work_remote, remote)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
            # The worker end is only needed inside the worker process.
            work_remote.close()

    def step_async(self, actions):
        """Send one action to each worker without waiting for the results."""
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """Collect the results of ``step_async`` and refresh the cached actions.

        Returns:
            (obs, rewards, lengths, infos): ``obs`` is a list with one stacked
            array per state feature, ``rewards`` and ``lengths`` are arrays with
            one entry per environment, ``infos`` maps each info key to a stacked
            array.
        """
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, lengths, infos = zip(*results)
        self._get_actions()
        return self._flatten_obs(obs), np.stack(rews), np.stack(lengths), self._flatten_info(infos)

    def step(self, actions):
        """Apply ``actions[i]`` in environment ``i`` and return the batched result."""
        assert(self.is_reset)
        self.step_async(actions)
        return self.step_wait()

    def close(self):
        """Stop all workers. Safe to call repeatedly and on a partially built instance."""
        # getattr: __del__ may call this when __init__ failed half-way.
        if getattr(self, 'closed', False):
            return
        remotes = getattr(self, 'remotes', ())
        processes = getattr(self, 'processes', ())
        if getattr(self, 'waiting', False):
            # Drain pending step results so the workers get to the close command.
            for remote in remotes:
                try:
                    remote.recv()
                except (EOFError, OSError):
                    pass  # worker already gone; nothing to drain
            self.waiting = False
        for remote in remotes:
            try:
                remote.send(('close', None))
            except OSError:
                pass  # best-effort shutdown: a dead worker must not block the others
        for process in processes:
            process.join()
        for remote in remotes:
            remote.close()
        self.closed = True

    def _get_actions(self):
        """Fetch and cache the actions currently offered by every environment."""
        for remote in self.remotes:
            remote.send(('get_actions', None))
        self.actions = [remote.recv() for remote in self.remotes]

    def get_actions(self):
        """Return the cached per-environment action lists."""
        assert(self.is_reset)
        return self.actions

    def reset(self):
        """Reset every environment; returns (batched obs, batched infos)."""
        for remote in self.remotes:
            remote.send(('reset', None))
        results = [remote.recv() for remote in self.remotes]
        obs, infos = zip(*results)
        self.is_reset = True
        self._get_actions()
        return self._flatten_obs(obs), self._flatten_info(infos)

    def _flatten_obs(self, obs):
        """Turn per-environment states into a list of per-feature stacked arrays."""
        return [np.stack(feature) for feature in zip(*obs)]

    def _flatten_info(self, info):
        """Merge per-environment info dicts into one dict of stacked arrays."""
        if not info:
            return {}
        return {key: np.stack([x[key] for x in info]) for key in info[0]}

    def __del__(self):
        self.close()

def get_carenvvec(number, seed = None):
    """Build an EnvVecs of ``number`` CarEnv workers.

    Args:
        number: how many environments (worker processes) to create
        seed: None lets every CarEnv pick its own random seed; otherwise the
            workers are seeded with seed, seed + 1, ...
    """
    if seed is None:
        return EnvVecs(CarEnv, number, (action_data, rd_data))
    # Index 2 of the argument tuple is CarEnv's random_seed; the 0 is only a
    # placeholder that EnvVecs overwrites with each worker's seed.
    return EnvVecs(CarEnv, number, (action_data, rd_data, 0), 2, seed)

Tests

In [7]:
# Sanity check: sum the raw reward grid over sz x sz x tsz windows (hours wrap
# around midnight) and count how many windows are completely empty.
with open(pkl_folder + 'grid_order_reward.pkl', 'rb') as f:
    rd = pickle.load(f)
rr = rd['reward'].reshape(50, 50, 24)
sz = 5
tsz = 5
# Append the first tsz hours again so windows can run past hour 23.
rr = np.concatenate((rr, rr[:,:,:tsz]),axis=2)
res = np.zeros(rr.shape)
for i in range(sz):
    for j in range(sz):
        for k in range(tsz):
            # Add the grid shifted by (i, j, k); summed over all shifts this
            # yields res[x, y, t] = sum of rr[x + i, y + j, t + k].
            res[:50 - i,:50 - j, :24 + tsz - k] += rr[i:,j:,k:]
# Drop the border cells whose windows are incomplete.
res = res[:-sz,:-sz,:-tsz]
print((res == 0).sum(), res)

13165 [[[0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.

In [8]:
# Single in-process environment built from the loaded action table and the
# normalized reward/demand grid; no seed is passed, so runs differ.
ce = CarEnv(action_data, rd_data)

In [9]:
# Smoke test: draw a start state, list the offered actions, then take action 0.
print(ce.reset())
print(ce.get_actions())
print(ce.step(0))

([0.5283012630239232, 0.5788719042144252, 0, 1.1966354227041558, 6.207372665405273], {'time': 3584})
[array([0.54627   , 0.52830126]), array([0.54280444, 0.5788719 ]), array([0.31004   , 0.52830126]), array([0.27686667, 0.5788719 ]), array([2174,    0]), array([14.26,  0.  ]), array([0.054157, 1.      ])]
([0.31004000000001497, 0.27686666666667004, 1, 2.8977822389235333, 2.783015727996826], 14.26, 2174, {'time': 5758})


In [10]:
# Actions offered at the state reached above, followed by another step with action 0.
print(ce.get_actions())
print(ce.step(0))

[array([0.314462, 0.3189  , 0.312988, 0.317224, 0.311666, 0.314154,
       0.312664, 0.318602, 0.31423 , 0.312988, 0.32064 , 0.31004 ]), array([0.28870667, 0.28519778, 0.26415778, 0.28382667, 0.26392   ,
       0.24642667, 0.26364444, 0.28578667, 0.27022667, 0.26415778,
       0.28406444, 0.27686667]), array([0.53502 , 0.513456, 0.70576 , 0.51946 , 0.4973  , 0.55896 ,
       0.5644  , 0.52308 , 0.57596 , 0.61112 , 0.51238 , 0.31004 ]), array([0.48346667, 0.49624889, 0.51355556, 0.54691111, 0.51877778,
       0.4608    , 0.46444444, 0.5068    , 0.53071111, 0.58376   ,
       0.46942222, 0.27686667]), array([1582, 1394, 2074, 1763, 1972, 1925, 1742, 2156, 2222, 2296, 1443,
          0]), array([ 9.7 , 10.73, 16.23, 11.9 , 13.3 , 10.4 , 10.62, 12.01, 14.35,
       16.7 , 10.2 ,  0.  ]), array([0.040703, 0.048611, 0.158374, 0.290067, 0.202006, 0.0733  ,
       0.049986, 0.056614, 0.167448, 0.274834, 0.033624, 1.      ])]
([0.5350200000000029, 0.4834666666666728, 2, 1.0461142671361017, 5.74

In [11]:
# Vectorized environment with 40 CarEnv worker processes (no explicit seed).
ev = get_carenvvec(40)

In [12]:
# Reset all workers; the observation comes back as one stacked array per state feature.
print(ev.reset())

([array([0.05116192, 0.31046204, 0.720888  , 0.80825595, 0.41579263,
       0.22452482, 0.82790867, 0.53288102, 0.97993803, 0.25922624,
       0.62198563, 0.35614321, 0.03615812, 0.73729586, 0.13965611,
       0.05677904, 0.20344274, 0.88991362, 0.91117747, 0.95209474,
       0.67061634, 0.8647112 , 0.18038643, 0.52240924, 0.82116881,
       0.35957872, 0.65971114, 0.01592468, 0.69477805, 0.56920922,
       0.43823464, 0.1094505 , 0.27685393, 0.87845169, 0.15847334,
       0.51651749, 0.5635747 , 0.39904473, 0.98705714, 0.17779077]), array([0.36597161, 0.19490209, 0.71308395, 0.04424366, 0.04212024,
       0.39829165, 0.9371378 , 0.40247466, 0.40263977, 0.81452487,
       0.29475595, 0.94947447, 0.78772761, 0.10317523, 0.926348  ,
       0.20082402, 0.77654082, 0.86754822, 0.22571368, 0.78056141,
       0.49405879, 0.43107837, 0.39251091, 0.97919694, 0.34638098,
       0.88854528, 0.45570949, 0.16184251, 0.244878  , 0.47286445,
       0.04623107, 0.05156903, 0.06153155, 0.6899816 , 0.7

In [13]:
# Per-environment action lists, then one batched step taking action 0 in every environment.
print(ev.get_actions())
print(ev.step([0] * len(ev.processes)))

[[array([0.061942  , 0.05116192]), array([0.34544444, 0.36597161]), array([0.36524   , 0.05116192]), array([0.33351111, 0.36597161]), array([3960,    0]), array([9.85, 0.  ]), array([0.032217, 1.      ])], [array([0.33836   , 0.31046204]), array([0.18048889, 0.19490209]), array([0.55816   , 0.31046204]), array([0.53824444, 0.19490209]), array([3272,    0]), array([17.55,  0.  ]), array([0.047289, 1.      ])], [array([0.74934 , 0.74534 , 0.73946 , 0.736956, 0.74434 , 0.75428 ,
       0.72028 , 0.74156 , 0.7397  , 0.70372 , 0.710592, 0.720888]), array([0.73644444, 0.7134    , 0.71875556, 0.71020444, 0.73477778,
       0.71337778, 0.73955556, 0.73106667, 0.69626667, 0.73335556,
       0.6812    , 0.71308395]), array([0.53876 , 0.6347  , 0.61816 , 0.596506, 0.65054 , 0.5644  ,
       0.62548 , 0.60128 , 0.45926 , 0.63248 , 0.61602 , 0.720888]), array([0.45031111, 0.55706667, 0.49986667, 0.48346222, 0.4982    ,
       0.45093333, 0.58942222, 0.55382222, 0.57244444, 0.45348889,
       0.5887