In [1]:
from stable_baselines3 import SAC, TD3, DDPG
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac.policies import Actor as SACActor
from stable_baselines3.sac.policies import MlpPolicy

In [None]:
# Create the SAC model together with its Pendulum-v1 training environment.
# A fixed seed makes this stochastic training run (and every artifact saved
# from it in later cells) reproducible on Restart & Run All.
model = SAC("MlpPolicy", "Pendulum-v1", verbose=0, learning_rate=1e-3, seed=0)

# Train the model for a short demo run.
model.learn(total_timesteps=2000, progress_bar=True)


In [None]:

# Save the complete model; only a model saved this way can be trained further later.
model.save("sac_pendulum")

# The saved model does not contain the replay buffer, so the reloaded buffer is empty.
loaded_model = SAC.load("sac_pendulum")
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")

# Save the replay buffer separately ...
model.save_replay_buffer("sac_replay_buffer")

# ... and load it into the reloaded model.
loaded_model.load_replay_buffer("sac_replay_buffer")

# Now the reloaded replay buffer is no longer empty.
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")

# Save the policy (and its actor / critic networks) independently from the model.
# Note: if you don't save the complete model with `model.save()`
# you cannot continue training afterward.
model.policy.save("policy_pendulum")
model.actor.save("actor_pendulum")
model.critic.save("critic_pendulum")

# Retrieve the training environment to evaluate on.
env = model.get_env()

# Evaluate the in-memory policy.
mean_reward, std_reward = evaluate_policy(model.policy, env, n_eval_episodes=10, deterministic=True)

print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

# Load the policy independently from the model.
saved_policy = MlpPolicy.load("policy_pendulum")

# Evaluate the loaded policy.
mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)

print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

In [None]:
# Warm-start a TD3 agent from the SAC actor saved above.
#
# `model_td3.actor.load(...)` cannot do this: `load` is a classmethod that returns a
# *new* object (the instance is left untouched), and TD3's Actor is a different class
# whose constructor does not accept the SAC actor's saved arguments.
# Instead, load the SAC actor with its own class and copy the weights layer by layer.
#
# Layer layouts (NOTE(review): per SB3's sac/td3 policies — confirm on version upgrades):
#   SAC actor: latent_pi = [Linear, ReLU, Linear, ReLU], mu = Linear (mean; tanh is
#              applied when the action is sampled/taken deterministically)
#   TD3 actor: mu = [Linear, ReLU, Linear, ReLU, Linear, Tanh]
# Shapes only match if TD3 uses SAC's default hidden sizes, so set net_arch explicitly
# (TD3's own default MlpPolicy net_arch is different).
model_td3 = TD3(
    "MlpPolicy",
    "Pendulum-v1",
    verbose=0,
    learning_rate=1e-3,
    policy_kwargs=dict(net_arch=[256, 256]),
)

sac_actor_state = SACActor.load("actor_pendulum", device=model_td3.device).state_dict()

SAC_TO_TD3_ACTOR_LAYERS = {"latent_pi.0": "mu.0", "latent_pi.2": "mu.2", "mu": "mu.4"}
td3_actor_state = {
    f"{td3_layer}.{tensor_name}": sac_actor_state[f"{sac_layer}.{tensor_name}"]
    for sac_layer, td3_layer in SAC_TO_TD3_ACTOR_LAYERS.items()
    for tensor_name in ("weight", "bias")
}

# strict (default) loading fails loudly if the layouts ever stop matching.
model_td3.actor.load_state_dict(td3_actor_state)
# Keep the target actor in sync with the warm-started actor.
model_td3.actor_target.load_state_dict(td3_actor_state)

In [None]:
# Build a fresh SAC model, then restore its networks from the separately saved policy.
model_load = SAC("MlpPolicy", "Pendulum-v1")

# `policy.load(...)` / `critic.load(...)` are classmethods that *return* a new object,
# so calling them on an instance restores nothing; `actor.load_from_vector` expects a
# parameter vector, not a file path. Copy the saved weights into the existing policy
# in place instead: SACPolicy holds the actor, critic and critic target as submodules,
# so `model_load.actor` / `.critic` / `.critic_target` (aliases into `model_load.policy`)
# and the optimizers built over their parameters all see the restored values.
restored_policy = MlpPolicy.load("policy_pendulum", device=model_load.device)
model_load.policy.load_state_dict(restored_policy.state_dict())

# NOTE(review): optimizer state, the entropy coefficient and the replay buffer are not
# part of the policy file; use `model.save()` + `SAC.load()` to resume training exactly.

In [None]:
# Continue training the restored model for another 2000 timesteps.
model_load.learn(
    total_timesteps=2000,
    progress_bar=True,
    log_interval=10,
)