In [None]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import BaseCallback
from tmrl import get_environment
from time import sleep

In [None]:
import os    
# Tell the Intel OpenMP runtime to tolerate more than one copy of itself being
# loaded into this process (KMP_DUPLICATE_LIB_OK).
# NOTE(review): this is typically a workaround for the "OMP: Error #15" duplicate
# libiomp crash; the runtime may already have initialised during the imports in the
# previous cell, so consider setting this before those imports — confirm.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Periodically save the model being trained.

    Every ``check_freq`` calls to ``_on_step`` the model is saved to
    ``<save_path>/best_model_<n_calls>`` via ``self.model.save``. Despite the
    file name, these are plain periodic snapshots, not the best-scoring model.

    Args:
        check_freq: number of ``_on_step`` calls between checkpoints; must be >= 1.
        save_path: checkpoint directory, created before training if missing.
            ``None`` disables saving.
        verbose: verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: if ``check_freq`` is smaller than 1.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        # Fail fast instead of hitting ZeroDivisionError on the first training step.
        if check_freq < 1:
            raise ValueError(f"check_freq must be >= 1, got {check_freq!r}")
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory (if any) before training starts."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint every ``check_freq`` calls; always continue training."""
        # save_path=None is accepted by _init_callback, so honour it here too
        # rather than crashing in os.path.join(None, ...).
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f'best_model_{self.n_calls}')
            self.model.save(model_path)

        # Returning True tells SB3 not to abort training.
        return True

In [None]:
# ---- Training / evaluation configuration ----
CHECK_FREQ_NUMB = 10_000        # callback calls between saved checkpoints
TOTAL_TIMESTEP_NUMB = 300_000   # total_timesteps passed to model.learn
LEARNING_RATE = 3e-4            # PPO learning rate
N_STEPS = 1_024                 # PPO rollout length (n_steps)
GAMMA = 0.995                   # discount factor
BATCH_SIZE = 256                # PPO minibatch size
N_EPOCHS = 10                   # PPO optimisation epochs per rollout
DOWN_SAMPLE_RATE = 3            # NOTE(review): not referenced in the visible cells
SKIP_NUMB = 2                   # NOTE(review): not referenced in the visible cells
EPISODE_NUMBERS = 10            # presumably the number of evaluation episodes

In [None]:
# Where model checkpoints (callback save_path) and TensorBoard logs (PPO tensorboard_log) go.
CHECKPOINT_DIR, LOG_DIR = './train/', './logs/'

In [None]:
# Checkpoint the model into CHECKPOINT_DIR every CHECK_FREQ_NUMB callback calls.
callback = TrainAndLoggingCallback(
    save_path=CHECKPOINT_DIR,
    check_freq=CHECK_FREQ_NUMB,
)

In [None]:
from gym import Env
from gym.spaces import Box, MultiDiscrete

In [None]:
class TrackMania(Env):
    """Gym wrapper around the TMRL TrackMania environment for Stable-Baselines3.

    - Actions: ``MultiDiscrete([3, 3, 3])``; each component in {0, 1, 2} is
      shifted to {-1, 0, 1} before being sent to the TMRL environment.
    - Observations: only ``obs[2]`` of the TMRL observation is returned.
    - Reward: the raw TMRL reward divided by ``reward_scale``, plus
      ``int(progress_delta * progress_scale)``, minus ``speed_drop_penalty``
      whenever speed falls by at least ``speed_drop_threshold`` in one step.

    Args:
        reward_scale: divisor applied to the raw TMRL reward.
        progress_scale: multiplier on the per-step change of ``obs[1][0]``.
        speed_drop_threshold: one-step speed loss that triggers the penalty.
        speed_drop_penalty: amount subtracted from the reward when triggered.
    """

    def __init__(self, reward_scale=10, progress_scale=20,
                 speed_drop_threshold=5, speed_drop_penalty=500):
        super().__init__()
        self.game = get_environment()
        # Leave time to focus the game window after the environment starts.
        sleep(1.0)
        # assumes obs[2] from TMRL is a (4, 19) non-negative array — TODO confirm against tmrl config.
        self.observation_space = Box(low=0.0, high=np.inf, shape=(4, 19), dtype=np.float32)
        self.action_space = MultiDiscrete([3, 3, 3])
        self.reward_scale = reward_scale
        self.progress_scale = progress_scale
        self.speed_drop_threshold = speed_drop_threshold
        self.speed_drop_penalty = speed_drop_penalty
        # Previous-step progress (obs[1][0]) and speed (obs[0][0]) used for reward shaping.
        self.comple = 0.0
        self.speed = 0.0

    def step(self, action):
        """Apply a discrete action; return ``(obs[2], shaped_reward, done, info)``."""
        # Shift every component from {0, 1, 2} to {-1, 0, 1} (produces a new array).
        act = np.asarray(action, dtype=np.float32) - 1.0
        obs, reward, done, info = self.game.step(act)
        # obs[0][0] is treated as speed and obs[1][0] as track completion
        # (per the notebook's variable names) — TODO confirm against tmrl.
        speed = obs[0][0]
        comple = obs[1][0]

        reward = reward / self.reward_scale
        # NOTE(review): int() truncates toward zero, so per-step progress deltas
        # smaller than 1 / progress_scale add nothing; kept to preserve shaping.
        reward += int((comple - self.comple) * self.progress_scale)
        # Penalise a sudden loss of speed between consecutive steps.
        if self.speed - speed >= self.speed_drop_threshold:
            reward -= self.speed_drop_penalty

        self.comple = comple
        self.speed = speed
        return obs[2], reward, done, info

    def render(self, mode='human'):
        """No-op; rendering is not implemented by this wrapper."""
        pass

    def reset(self):
        """Reset the game and shaping state; return ``obs[2]`` of the new observation."""
        obs = self.game.reset()
        self.comple = 0.0
        self.speed = 0.0
        return obs[2]

    def wait(self):
        """Forward to the underlying TMRL environment's ``wait()``."""
        self.game.wait()

In [None]:
# def model(obs):
#     """
#     simplistic policy
#     """
#     deviation = obs[2].mean(0)
#     deviation /= (deviation.sum() + 0.001)
#     steer = 0
#     for i in range(19):
#         steer += (i - 9) * deviation[i]
#     steer = - np.tanh(steer * 4)
#     steer = min(max(steer, -1.0), 1.0)
#     return np.array([1.0, 0.0, steer])

In [None]:
# env = get_environment()  # retrieve the TMRL Gym environment
# sleep(1.0)  # just so we have time to focus the TM20 window after starting the script

In [None]:
# Create the TrackMania wrapper (its constructor starts the TMRL environment and sleeps 1 s).
env = TrackMania()

In [None]:
# Initial reset before training.
# NOTE(review): this `obs` is overwritten by a later reset before it is used, and
# SB3's learn() presumably resets the env itself — this call looks redundant.
obs = env.reset() # reset environment

In [None]:
# PPO agent with an MLP policy; all hyper-parameters come from the config cell.
model = PPO(
    'MlpPolicy',
    env,
    learning_rate=LEARNING_RATE,
    n_steps=N_STEPS,
    batch_size=BATCH_SIZE,
    n_epochs=N_EPOCHS,
    gamma=GAMMA,
    tensorboard_log=LOG_DIR,
    verbose=1,
)

In [None]:
%%time
model.learn(total_timesteps=TOTAL_TIMESTEP_NUMB, callback=callback)

In [None]:
# Start a fresh episode after training and keep the returned observation.
obs = env.reset()

In [None]:
# Inspect the observation returned by the most recent reset.
obs

In [None]:
# Roll out the trained policy for EPISODE_NUMBERS episodes; print each episode's
# summed reward ("score") and length in steps ("time"), then their means.
# NOTE(review): model.predict samples stochastically by default; pass
# deterministic=True if a greedy evaluation is intended.
n_episodes = EPISODE_NUMBERS
mean_score = 0.0
mean_time = 0.0
for episode in range(n_episodes):
    t_st = 0
    score = 0
    obs = env.reset()
    while True:  # rtgym ensures this runs at 20Hz by default
        # `_state` rather than `_` so IPython's last-output variable is not clobbered.
        act, _state = model.predict(obs)  # compute action
        obs, rew, done, info = env.step(act)  # apply action (rtgym ensures healthy time-steps)
        score += rew
        t_st += 1
        if done:
            break
    print(f"score: {score}, time: {t_st}")
    mean_score += score
    mean_time += t_st
# Guard against EPISODE_NUMBERS == 0 to avoid ZeroDivisionError.
if n_episodes > 0:
    print(f"m_score: {mean_score/n_episodes}, m_time: {mean_time/n_episodes}")

In [None]:
# Reload the last checkpoint written by the training callback. It saves to
# CHECKPOINT_DIR/best_model_<n_calls> every CHECK_FREQ_NUMB calls, so the last
# multiple of CHECK_FREQ_NUMB not exceeding TOTAL_TIMESTEP_NUMB is the newest
# checkpoint guaranteed to exist after a full run (SB3 appends ".zip" on save).
LAST_CHECKPOINT_STEP = (TOTAL_TIMESTEP_NUMB // CHECK_FREQ_NUMB) * CHECK_FREQ_NUMB
model = PPO.load(os.path.join(CHECKPOINT_DIR, f'best_model_{LAST_CHECKPOINT_STEP}.zip'))