In [1]:
import os
import gym
import tqdm
import json
import mlflow
import pickle
import zipfile
import argparse
import datetime
from os import path
from dynaconf import Dynaconf
from ray.rllib.algorithms.dqn import DQN
from algorithms_with_statistics.basic_dqn import DQNWithLogging
from replay_buffer.ber import BlockReplayBuffer
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind
from utils import init_ray, check_path, logs_with_timeout, convert_np_arrays
from mlflow.exceptions import MlflowException
from func_timeout import FunctionTimedOut

In [4]:
checkpoint_path = "~/autodl-tmp/checkpoints/"
log_path = "~/autodl-tmp/loggings/"

In [3]:
init_ray("./ray_config.yml")

2023-06-28 02:05:05,673	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


In [5]:
settings = Dynaconf(envvar_prefix="DYNACONF", settings_files="./settings/ex_experiments/AutoDL/atari_dqn.yaml")

In [6]:
env = "SpaceInvadersNoFrameskip"

In [7]:
run_name = "DQN_ER_" + datetime.datetime.now().strftime("%Y%m%d")

In [8]:
# Check path available
log_path = path.join(log_path, env)
check_path(log_path)
log_path = path.join(log_path, run_name)
check_path(log_path)
checkpoint_path = path.join(checkpoint_path, env)
check_path(checkpoint_path)
checkpoint_path = path.join(checkpoint_path, run_name)
check_path(checkpoint_path)

In [9]:
checkpoint_path

'~/autodl-tmp/checkpoints/SpaceInvadersNoFrameskip/DQN_ER_20230628'

In [10]:
hyper_parameters = settings["atari-basic-dqn"].config.to_dict()

In [11]:
# Set MLflow
mlflow.set_tracking_uri("https://seventheli-mlflow.eu.cpolar.io/")
mlflow.set_experiment(experiment_name="Simple-%s" % env)
mlflow_client = mlflow.tracking.MlflowClient()

In [12]:
mlflow_run = mlflow.start_run(run_name=run_name, tags={"mlflow.user": "AutoDL"})
mlflow.log_params(hyper_parameters["replay_buffer_config"])
mlflow.log_params({key: hyper_parameters[key] for key in hyper_parameters.keys() if key not in ["replay_buffer_config"]})

In [13]:
algorithm = DQN(config=hyper_parameters, env=env)

2023-06-28 02:05:48,630	INFO algorithm.py:501 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
  logger.warn(
A.L.E: Arcade Learning Environment (version 0.7.5+db37282)
[Powered by Stella]


In [14]:
with open(os.path.join(checkpoint_path, "%s config.pyl" % run_name), "wb") as f:
    _ = algorithm.config.to_dict()
    _.pop("multiagent")
    pickle.dump(_, f)
mlflow.log_artifacts(checkpoint_path)

In [15]:
keys_to_extract = {"episode_reward_max", "episode_reward_min", "episode_reward_mean"}

In [None]:
for i in tqdm.tqdm(range(1, 1000)):
    result = algorithm.train()
    time_used = result["time_total_s"]
    evaluation = result.get("evaluation", None)
    sampler = result.get("sampler_results", None)
    try:
        if evaluation is not None:
            _save = {"eval_" + key: evaluation[key] for key in keys_to_extract if key in evaluation}
            logs_with_timeout(_save, step=result["episodes_total"])
        if i % 10 == 0:
            learner_data = result["info"].copy()
            if learner_data["learner"].get("time_usage", None) is not None:
                logs_with_timeout(learner_data["learner"].get("time_usage"), step=result["episodes_total"])
            learner_data.pop("learner")
            logs_with_timeout(learner_data, step=result["episodes_total"])
            _save = {key: sampler[key] for key in keys_to_extract if key in sampler}
            logs_with_timeout(_save, step=result["episodes_total"])
        if i % 100 == 0:
            algorithm.save_checkpoint(checkpoint_path)
    except FunctionTimedOut:
        tqdm.tqdm.write("logging failed")
    except MlflowException:
        tqdm.tqdm.write("logging failed")
    with open(path.join(log_path, str(i) + ".json"), "w") as f:
        result["config"] = None
        json.dump(convert_np_arrays(result), f)
    if time_used >= 180000 or result["episode_reward_mean"] > 30000:
        break
with zipfile.ZipFile(os.path.join(log_path, '%s_log.zip' % run_name), 'w') as f:
    for file in os.listdir(log_path):
        f.write(os.path.join(log_path, file))
mlflow.log_artifacts(log_path)

  0%|          | 0/999 [00:00<?, ?it/s]