Skip to content

Commit 9ba0ab5

Browse files
committed
Try to upload best model to WandB during training
1 parent b617130 commit 9ba0ab5

File tree

2 files changed: 15 additions (+15) and 7 deletions (-7).

rl_algo_impls/runner/train.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def train(args: TrainArgs):
4949
print(hyperparams)
5050
config = Config(args, hyperparams, os.getcwd())
5151

52-
wandb_enabled = args.wandb_project_name
52+
wandb_enabled = bool(args.wandb_project_name)
5353
if wandb_enabled:
5454
wandb.tensorboard.patch(
5555
root_logdir=config.tensorboard_summary_path, pytorch=True
@@ -108,6 +108,7 @@ def train(args: TrainArgs):
108108
else None,
109109
best_video_dir=config.best_videos_dir,
110110
additional_keys_to_log=config.additional_keys_to_log,
111+
wandb_enabled=wandb_enabled,
111112
)
112113
callbacks: List[Callback] = [eval_callback]
113114
if config.hyperparams.microrts_reward_decay_callback:
@@ -151,13 +152,8 @@ def train(args: TrainArgs):
151152

152153
if wandb_enabled:
153154
shutil.make_archive(
154-
os.path.join(wandb.run.dir, config.model_dir_name()),
155+
os.path.join(wandb.run.dir, config.model_dir_name()), # type: ignore
155156
"zip",
156157
config.model_dir_path(),
157158
)
158-
shutil.make_archive(
159-
os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
160-
"zip",
161-
config.model_dir_path(best=True),
162-
)
163159
wandb.finish()

rl_algo_impls/shared/callbacks/eval_callback.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import os
3+
import shutil
34
from time import perf_counter
45
from typing import Dict, List, Optional, Union
56

@@ -132,6 +133,7 @@ def __init__(
132133
ignore_first_episode: bool = False,
133134
additional_keys_to_log: Optional[List[str]] = None,
134135
score_function: str = "mean-std",
136+
wandb_enabled: bool = False,
135137
) -> None:
136138
super().__init__()
137139
self.policy = policy
@@ -157,6 +159,7 @@ def __init__(
157159
self.ignore_first_episode = ignore_first_episode
158160
self.additional_keys_to_log = additional_keys_to_log
159161
self.score_function = score_function
162+
self.wandb_enabled = wandb_enabled
160163

161164
def on_step(self, timesteps_elapsed: int = 1) -> bool:
162165
super().on_step(timesteps_elapsed)
@@ -196,6 +199,15 @@ def evaluate(
196199
assert self.best_model_path
197200
self.policy.save(self.best_model_path)
198201
print("Saved best model")
202+
if self.wandb_enabled:
203+
import wandb
204+
205+
best_model_name = os.path.split(self.best_model_path)[-1]
206+
shutil.make_archive(
207+
os.path.join(wandb.run.dir, best_model_name), # type: ignore
208+
"zip",
209+
self.best_model_path,
210+
)
199211
self.best.write_to_tensorboard(
200212
self.tb_writer, "best_eval", self.timesteps_elapsed
201213
)

0 commit comments

Comments
 (0)