Skip to content

Commit

Permalink
Adding Tensorboard support.
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Feb 19, 2020
1 parent 4d071d6 commit 9a8e382
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 7 deletions.
4 changes: 4 additions & 0 deletions digideep/agent/ddpg/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def step(self):

monitor("/update/loss_actor", loss_actor.item())
monitor("/update/loss_critic", loss_critic.item())

self.session.writer.add_scalar('loss/actor', loss_actor.item())
self.session.writer.add_scalar('loss/critic', loss_critic.item())

self.state["i_step"] += 1

def update(self):
Expand Down
5 changes: 5 additions & 0 deletions digideep/agent/ppo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def step(self):
monitor("/update/value_loss", value_loss.item())
monitor("/update/action_loss", action_loss.item())
monitor("/update/dist_entropy", dist_entropy.item())

self.session.writer.add_scalar('loss/overall', Loss.item())
self.session.writer.add_scalar('loss/value', value_loss.item())
self.session.writer.add_scalar('loss/action', action_loss.item())
self.session.writer.add_scalar('loss/dist_entropy', dist_entropy.item())

## Candidates for monitoring
# ratio.item()
Expand Down
4 changes: 4 additions & 0 deletions digideep/agent/sac/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def step(self):
monitor("/update/loss/softq", softq_loss.item())
monitor("/update/loss/value", value_loss.item())

self.session.writer.add_scalar('loss/actor', actor_loss.item())
self.session.writer.add_scalar('loss/softq', softq_loss.item())
self.session.writer.add_scalar('loss/value', value_loss.item())

# for key,item in locals().items():
# if isinstance(item, torch.Tensor):
# # print("item =", type(item))
Expand Down
4 changes: 4 additions & 0 deletions digideep/agent/sacv2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ def step(self):
monitor("/update/loss/critic1", qf1_loss.item())
monitor("/update/loss/critic2", qf2_loss.item())

self.session.writer.add_scalar('loss/actor', actor_loss.item())
self.session.writer.add_scalar('loss/critic1', qf1_loss.item())
self.session.writer.add_scalar('loss/critic2', qf2_loss.item())

# 'loss/entropy_loss', ent_loss: alpha_loss.item()
# 'entropy_temprature/alpha', alpha: alpha_tlogs.item()
self.state["i_step"] += 1
Expand Down
3 changes: 3 additions & 0 deletions digideep/environment/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Explorer:
def __init__(self, session, agents=None, **params):
self.agents = agents
self.params = params
self.session = session

# Create models
extra_env_kwargs = self.params.get("extra_env_kwargs", {})
Expand Down Expand Up @@ -124,7 +125,9 @@ def report_rewards(self, infos):
self.state["n_episode"] += 1

self.monitor_n_episode()

monitor("/reward/"+self.params["mode"]+"/episodic", rew, window=self.params["win_size"])
self.session.writer.add_scalar('reward/'+self.params["mode"], rew)

def close(self):
"""It closes all environments.
Expand Down
42 changes: 35 additions & 7 deletions digideep/pipeline/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(self, root_path):
self.state['path_checkpoints'] = os.path.join(self.state['path_session'], 'checkpoints')
self.state['path_monitor'] = os.path.join(self.state['path_session'], 'monitor')
self.state['path_videos'] = os.path.join(self.state['path_session'], 'videos')
self.state['path_tensorboard'] = os.path.join(self.state['path_session'], 'tensorboard')
# Hyper-parameters basically is a snapshot of intial parameter engine's state.
self.state['file_cpanel'] = os.path.join(self.state['path_session'], 'cpanel.json')
self.state['file_params'] = os.path.join(self.state['path_session'], 'params.yaml')
Expand All @@ -161,8 +162,8 @@ def __init__(self, root_path):
self.initLogger()
self.initVarlog()
self.initProlog()
if self.args["visdom"]:
self.initVisdom()
self.initTensorboard()
self.initVisdom()
# TODO: We don't need the "SaaM" when are loading from a checkpoint.
# if not self.is_playing:
if not self.is_loading:
Expand Down Expand Up @@ -202,6 +203,30 @@ def initProlog(self):
profiler.set_output_file(self.state['file_prolog'])
KeepTime.set_level(self.args["profiler_level"])

def initTensorboard(self):
"""
Will initialize the SummaryWriter for tensorboard logging.
Link: https://pytorch.org/docs/stable/tensorboard.html
"""
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(log_dir=self.state['path_tensorboard'])

if self.args["tensorboard"]:
# Run a dedicated Tensorboard server:
from tensorboard import program
tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', self.state['path_tensorboard']])
url = tb.launch()
logger.warn("Access Tensorboard through: " + str(url))
else:
# Nullify the attributes so time would not be wasted logging.
for attr in dir(self.writer):
if attr.startswith("add_") or (attr=="flush") or (attr=="close"):
setattr(self.writer, attr, lambda *args, **kw: None)



def initVisdom(self):
"""
This function initializes the connection to the Visdom server. The Visdom server must be running.
Expand All @@ -211,11 +236,12 @@ def initVisdom(self):
visdom -port 8097 &
"""
from digideep.utility.visdom_engine.Instance import VisdomInstance
if not self.dry_run:
VisdomInstance(port=self.args["visdom_port"], log_to_filename=self.state["file_visdom"], replay=True)
else:
VisdomInstance(port=self.args["visdom_port"])
if self.args["visdom"]:
from digideep.utility.visdom_engine.Instance import VisdomInstance
if not self.dry_run:
VisdomInstance(port=self.args["visdom_port"], log_to_filename=self.state["file_visdom"], replay=True)
else:
VisdomInstance(port=self.args["visdom_port"])

def createSaaM(self):
""" SaaM = Session-as-a-Module
Expand Down Expand Up @@ -354,6 +380,8 @@ def parse_arguments(self):
## Visdom Server
parser.add_argument('--visdom', action='store_true', help="Whether to use visdom or not!")
parser.add_argument('--visdom-port', metavar=('<n>'), default=8097, type=int, help="The port of visdom server, it's on 8097 by default.")
## Tensorboard
parser.add_argument('--tensorboard', action='store_true', help="Whether to use tensorboard or not!")
## Monitor Thread
parser.add_argument('--monitor-cpu', action="store_true", help="Use to monitor CPU resource statistics on Visdom.")
parser.add_argument('--monitor-gpu', action="store_true", help="Use to monitor GPU resource statistics on Visdom.")
Expand Down

0 comments on commit 9a8e382

Please sign in to comment.