In [1]:
import os
from datetime import datetime

import numpy as np
import torch
from pytz import timezone
from stable_baselines3 import DDPG, SAC, TD3, DQN
from stable_baselines3.common.noise import NormalActionNoise

import common.mlflow_sb3_helper as mlf
from env.env_simple_move import HumanMoveSimpleAction

In [2]:
# Shared objects for the cells below: training environments, the MLflow
# connection, and name -> class lookup tables used to build models from config.
env = HumanMoveSimpleAction()                       # used for DDPG, SAC and TD3
env_disc = HumanMoveSimpleAction(continuous=False)  # discrete-action variant, used for DQN

# The tracking server address can be overridden without editing the notebook.
MLFLOW_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://192.168.0.206:2670")
mlflow_server = mlf.MLflowServerHelper(MLFLOW_URI)

TZ = timezone('Europe/Moscow')  # timezone for the timestamp-based experiment ids and seeds
models = {'DQN': DQN, 'DDPG': DDPG, 'SAC': SAC, 'TD3': TD3}
activations = {'ReLU': torch.nn.ReLU, 'Tanh': torch.nn.Tanh}
device = "cuda" if torch.cuda.is_available() else "cpu"

# Algorithm name -> MLflow artifact URI of its trained model (filled by the training loop).
mlflow_path = {}

In [3]:
# Experiment configuration. The training loop hands this dict to `learn_and_fix`
# as `parameters`, and overwrites 'algorithm_name' and 'seed' for every run.
exp_params = {
    'env_name': env.name(),
    'algorithm_name': '',
    'exp_id': 'exp_' + datetime.now(TZ).strftime("%d%m_%H%M%S"),
    'seed': 21,
    'net': {
        'activation': 'ReLU',
        'pi': [256, 256],
        'qf': [256, 256],
        'vf': [256, 256],
    },
    'training': {
        'iteration_count': 1,        # NOTE(review): not read by any visible cell
        'episode_count': 50_000,
        'policy': 'MlpPolicy',
        'learning_rate': 0.003,
        'buffer_size': 1_000_000,
        'learning_starts': 100,
        'batch_size': 256,
        'tau': 0.005,
        'gamma': 0.99,
        'verbose': 0,
        'device': device,
    },
    'validation': {
        'validate_agent_every_n_eps': 20_000,
        'log_interval': 100,
    },
    'evaluation': {
        'episode_count': 1,          # NOTE(review): not read by any visible cell
    },
}

# Policy network settings derived from exp_params['net']. The training loop
# currently keeps `policy_kwargs=policy_kwargs` commented out, so this is unused.
policy_kwargs = {
    'activation_fn': activations[exp_params['net']['activation']],
    'net_arch': {head: exp_params['net'][head] for head in ('pi', 'qf', 'vf')},
}

exp_name = f"env_{exp_params['env_name']}_{exp_params['exp_id']}"
experiment_id = mlflow_server.new_experiment(exp_name)


2024/09/02 12:29:39 INFO mlflow.tracking.fluent: Experiment with name 'env_HumanMoveSimple_exp_0209_122939' does not exist. Creating a new experiment.


In [4]:
# Train every algorithm in `models` and record the MLflow artifact URI of each
# saved model in `mlflow_path`. DQN trains on the discrete-action `env_disc`;
# DDPG, SAC and TD3 train on `env`, and TD3 additionally gets Gaussian action noise.
for name, m in models.items():
    print(name)
    exp_params['algorithm_name'] = name
    # Time-of-day seed so consecutive runs differ. It is also passed to the model
    # below, so the seed logged with `exp_params` is the one actually used.
    exp_params['seed'] = int(datetime.now(TZ).strftime("%H%M%S"))

    is_discrete = name == 'DQN'
    train_cfg = exp_params['training']

    # Hyper-parameters shared by every algorithm: one constructor call instead of
    # three copy-pasted ones.
    model_kwargs = dict(
        # policy_kwargs=policy_kwargs,  # NOTE(review): built in the config cell but not passed yet
        learning_rate=train_cfg['learning_rate'],
        buffer_size=train_cfg['buffer_size'],
        learning_starts=train_cfg['learning_starts'],
        batch_size=train_cfg['batch_size'],
        tau=train_cfg['tau'],
        gamma=train_cfg['gamma'],
        verbose=train_cfg['verbose'],
        device=device,
        seed=exp_params['seed'],
    )
    if name == 'TD3':
        # Exploration noise for TD3 only. It goes through the single constructor
        # call below, so it cannot be discarded by another branch building a new model.
        n_actions = env.action_space.shape[-1]
        model_kwargs['action_noise'] = NormalActionNoise(mean=np.zeros(n_actions),
                                                         sigma=0.1 * np.ones(n_actions))

    model = m(train_cfg['policy'], env_disc if is_discrete else env, **model_kwargs)

    # Rendering environment (rgb_array) handed to learn_and_fix; its action space
    # must match the one the model was trained on.
    eval_env = (HumanMoveSimpleAction(continuous=False, render_mode='rgb_array') if is_discrete
                else HumanMoveSimpleAction(render_mode='rgb_array'))

    # `run_exp_name` (not `exp_name`) so the experiment name from the config cell is not overwritten.
    art_loc, run_exp_name, run_id = mlflow_server.learn_and_fix(
        model=model,
        env=eval_env,
        run_name=name,
        episode_count=train_cfg['episode_count'],
        parameters=exp_params,
        experiment_id=experiment_id,
        checkpoint_interval=exp_params['validation']['validate_agent_every_n_eps'],
        log_interval=exp_params['validation']['log_interval'])

    mlflow_path[name] = f'{art_loc}/{run_id}/artifacts/{run_exp_name}/sb3/model.zip'


DQN


Output()

Moviepy - Building video env_HumanMoveSimple_exp_0209_122939//agent.mp4.
Moviepy - Writing video env_HumanMoveSimple_exp_0209_122939//agent.mp4



                                                               

Moviepy - Done !
Moviepy - video ready env_HumanMoveSimple_exp_0209_122939//agent.mp4
DDPG


Output()

In [19]:
# Artifact URI of each trained model, keyed by algorithm name (filled by the training loop).
mlflow_path

{'DDPG': 'mlflow-artifacts:/54/0db1a7bd6633464b9bc93dda987837a8/artifacts/env_HumanMoveSimple_exp_2508_115957/sb3/model.zip',
 'SAC': 'mlflow-artifacts:/54/deb23750b2e04dad878e96c9e0eb6524/artifacts/env_HumanMoveSimple_exp_2508_115957/sb3/model.zip',
 'TD3': 'mlflow-artifacts:/54/2cf8be691f4a4c869cd2b78e383ff68d/artifacts/env_HumanMoveSimple_exp_2508_115957/sb3/model.zip'}

In [26]:
# Download one trained model from MLflow and load it with its SB3 class.
m_name = 'DDPG'
# Relative directory (e.g. 'sb3/DDPG/') so the download does not need write
# access to the filesystem root.
local_path = os.path.join('sb3', m_name, '')
artifact_path = mlflow_path.get(m_name)
if artifact_path is None:
    raise KeyError(f"No trained model recorded for {m_name!r}; available: {sorted(mlflow_path)}")
print(artifact_path)
# Presumably load_artifact downloads the artifact into `local_path` — confirm in mlflow_sb3_helper.
mlflow_server.load_artifact(artifact_path, local_path)
read_model = models[m_name].load(os.path.join(local_path, 'model.zip'))

mlflow-artifacts:/54/0db1a7bd6633464b9bc93dda987837a8/artifacts/env_HumanMoveSimple_exp_2508_115957/sb3/model.zip


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

In [31]:
# Roll out one episode of the loaded model in a human-rendered environment and
# show the total reward. The environment's action space must match the one the
# model was trained on: discrete for DQN, continuous otherwise (as in the training loop).
seed = int(datetime.now(TZ).strftime("%H%M%S"))
if m_name == 'DQN':
    env_render = HumanMoveSimpleAction(continuous=False, render_mode='human', seed=seed)
else:
    env_render = HumanMoveSimpleAction(render_mode='human', seed=seed)

total_reward = 0.
step_reward = []
observation, _ = env_render.reset()
terminated = truncated = False
while not (terminated or truncated):
    # deterministic=True: evaluate the learned policy, not an exploratory
    # (epsilon-greedy / sampled) action.
    action, _ = read_model.predict(observation, deterministic=True)
    observation, reward, terminated, truncated, _ = env_render.step(action)
    total_reward += reward
    step_reward.append(reward)

total_reward

-201.229332863059

In [32]:
# Release the human-render environment opened by the rollout cell above.
env_render.close()