Fix Gradient Visualization (#66)
* Pass model names to logger

* Fix the gradient visualizer

* Add getter for gradient visualizer

* Remove unnecessary dependency on optimizer

* Update trainer for new gradient visualizer

* Remove unnecessary references to gradsum

* Add support for None gradient and define end epoch ops

* Call appropriate end epoch ops for GradientVisualizer
Aniket1998 authored and avik-pal committed Dec 8, 2018
1 parent 6b628a8 commit cb5b512
Showing 3 changed files with 55 additions and 14 deletions.
torchgan/logging/logger.py (10 changes: 7 additions & 3 deletions)
@@ -18,7 +18,7 @@ def __init__(self, trainer, losses_list, metrics_list=None, visdom_port=8097,
         self.logger_mid_epoch = []
         self.logger_end_epoch.append(ImageVisualize(trainer, writer=self.writer, test_noise=test_noise,
                                      nrow=nrow))
-        self.logger_mid_epoch.append(GradientVisualize([], writer=self.writer))
+        self.logger_mid_epoch.append(GradientVisualize(trainer.model_names, writer=self.writer))
         if metrics_list is not None:
             self.logger_end_epoch.append(MetricVisualize(metrics_list, writer=self.writer))
         self.logger_mid_epoch.append(LossVisualize(losses_list, writer=self.writer))
@@ -29,6 +29,9 @@ def get_loss_viz(self):
     def get_metric_viz(self):
         return self.logger_end_epoch[0]
 
+    def get_grad_viz(self):
+        return self.logger_mid_epoch[0]
+
     def register(self, visualize, *args, mid_epoch=True, **kwargs):
         if mid_epoch:
             self.logger_mid_epoch.append(visualize(*args, writer=self.writer, **kwargs))
@@ -49,9 +52,10 @@ def run_mid_epoch(self, trainer, *args):
     def run_end_epoch(self, trainer, epoch, *args):
         print("Epoch {} Summary".format(epoch))
         for logger in self.logger_mid_epoch:
-            if type(logger).__name__ is "LossVisualize" or\
-               type(logger).__name__ is "GradientVisualize":
+            if type(logger).__name__ is "LossVisualize":
                 logger(trainer)
+            elif type(logger).__name__ is "GradientVisualize":
+                logger.report_end_epoch()
             else:
                 logger(*args)
         for logger in self.logger_end_epoch:
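The new get_grad_viz getter pairs with get_loss_viz and get_metric_viz, so calling code can reach the gradient visualizer without indexing the internal lists itself. A minimal usage sketch, not part of the commit (trainer stands for any constructed torchgan trainer, and the model name 'generator' is hypothetical; real names come from trainer.model_names):

    # Sketch only: per Logger.__init__ above, logger_mid_epoch[0] is the
    # GradientVisualize instance, so get_grad_viz() hands it back directly.
    grad_viz = trainer.logger.get_grad_viz()
    loss_viz = trainer.logger.get_loss_viz()

    # Record the current gradient sums for one model after a backward pass.
    grad_viz.update_grads('generator', getattr(trainer, 'generator'))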
torchgan/logging/visualize.py (45 changes: 34 additions & 11 deletions)
@@ -140,23 +140,46 @@ def log_visdom(self):
                       opts=dict(title=name, xlabel="Time Step", ylabel="Metric Value"))
 
 class GradientVisualize(Visualize):
-    def log_tensorboard(self, name, gradsum):
-        self.writer.add_scalar('Gradients/{}'.format(name), gradsum, self.step)
+    def __init__(self, visualize_list, visdom_port=8097, log_dir=None, writer=None):
+        if visualize_list is None or len(visualize_list) == 0:
+            raise Exception('Gradient Visualizer requires list of model names')
+        self.logs = {}
+        for item in visualize_list:
+            self.logs[item] = [0.0]
+        self.step = 1
+        if TENSORBOARD_LOGGING == 1:
+            self.build_tensorboard(log_dir, writer)
+        if VISDOM_LOGGING == 1:
+            self.build_visdom(visdom_port)
+
+    def log_tensorboard(self, name):
+        self.writer.add_scalar('Gradients/{}'.format(name), self.logs[name][len(self.logs[name]) - 1], self.step)
 
-    def log_console(self, name, gradsum):
-        print('{} Gradients : {}'.format(name, gradsum))
+    def log_console(self, name):
+        print('{} Gradients : {}'.format(name, self.logs[name][len(self.logs[name]) - 1]))
 
-    def log_visdom(self, name, gradsum):
-        self.vis.line([gradsum], [self.step], win=name, update="append",
+    def log_visdom(self, name):
+        self.vis.line([self.logs[name][len(self.logs[name]) - 1]], [self.step], win=name, update="append",
                       opts=dict(title=name, xlabel="Time Step", ylabel="Gradient"))
 
+    def update_grads(self, name, model, eps=1e-5):
+        gradsum = 0.0
+        for p in model.parameters():
+            if p.grad is not None:
+                gradsum += torch.sum(p.grad ** 2).clone().item()
+        if gradsum > eps:
+            self.logs[name][len(self.logs[name]) - 1] += gradsum
+            model.zero_grad()
+
+    def report_end_epoch(self):
+        for key, val in self.logs.items():
+            print('{} Mean Gradients : {}'.format(key, sum(val) / len(val)))
+
     def __call__(self, trainer, **kwargs):
         for name in trainer.model_names:
-            model = getattr(trainer, name)
-            gradsum = 0.0
-            for p in model.parameters():
-                gradsum += p.norm(2).item()
-            super(GradientVisualize, self).__call__(name, gradsum, **kwargs)
+            super(GradientVisualize, self).__call__(name, **kwargs)
+            self.logs[name].append(0.0)
 
 
 class ImageVisualize(Visualize):
     def __init__(self, trainer, visdom_port=8097, log_dir=None, writer=None, test_noise=None, nrow=8):
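For intuition, the accumulation inside update_grads can be reproduced against a plain PyTorch module. A self-contained sketch, not part of the commit (the nn.Linear model is a stand-in for a Generator or Discriminator; eps=1e-5 mirrors the default in the diff):

    import torch
    import torch.nn as nn

    # Stand-in model with freshly computed gradients.
    model = nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()

    # Same logic as update_grads above: accumulate the sum of squared
    # gradients, skipping parameters whose grad is None (see the commit
    # message's "support for None gradient").
    gradsum = 0.0
    for p in model.parameters():
        if p.grad is not None:
            gradsum += torch.sum(p.grad ** 2).item()
    if gradsum > 1e-5:               # the eps threshold from the diff
        print('model Gradients : {}'.format(gradsum))
        model.zero_grad()            # gradients are cleared once recorded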
torchgan/trainer/trainer.py (14 changes: 14 additions & 0 deletions)
@@ -305,23 +305,34 @@ def train_iter(self):
         self.train_iter_custom()
         ldis, lgen, dis_iter, gen_iter = 0.0, 0.0, 0, 0
         loss_logs = self.logger.get_loss_viz()
+        grad_logs = self.logger.get_grad_viz()
         for name, loss in self.losses.items():
             if isinstance(loss, GeneratorLoss) and isinstance(loss, DiscriminatorLoss):
                 cur_loss = loss.train_ops(**self._get_arguments(self.loss_arg_maps[name]))
                 loss_logs.logs[name].append(cur_loss)
                 if type(cur_loss) is tuple:
                     lgen, ldis, gen_iter, dis_iter = lgen + cur_loss[0], ldis + cur_loss[1],\
                         gen_iter + 1, dis_iter + 1
+                for model_name in self.model_names:
+                    grad_logs.update_grads(model_name, getattr(self, model_name))
             elif isinstance(loss, GeneratorLoss):
                 if self.ncritic is None or\
                    self.loss_information["discriminator_iters"] % self.ncritic == 0:
                     cur_loss = loss.train_ops(**self._get_arguments(self.loss_arg_maps[name]))
                     loss_logs.logs[name].append(cur_loss)
                     lgen, gen_iter = lgen + cur_loss, gen_iter + 1
+                    for model_name in self.model_names:
+                        model = getattr(self, model_name)
+                        if isinstance(model, Generator):
+                            grad_logs.update_grads(model_name, model)
             elif isinstance(loss, DiscriminatorLoss):
                 cur_loss = loss.train_ops(**self._get_arguments(self.loss_arg_maps[name]))
                 loss_logs.logs[name].append(cur_loss)
                 ldis, dis_iter = ldis + cur_loss, dis_iter + 1
+                for model_name in self.model_names:
+                    model = getattr(self, model_name)
+                    if isinstance(model, Discriminator):
+                        grad_logs.update_grads(model_name, model)
         return lgen, ldis, gen_iter, dis_iter
 
     def eval_ops(self, epoch, **kwargs):
@@ -353,6 +364,9 @@ def train(self, data_loader, **kwargs):
             data_loader (torch.DataLoader): A DataLoader for the trainer to iterate over and train the
                 models.
         """
+        for name in self.optimizer_names:
+            getattr(self, name).zero_grad()
+
         for epoch in range(self.start_epoch, self.epochs):
 
             for model in self.model_names:
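These per-iteration update_grads calls feed the running sums that report_end_epoch averages when run_end_epoch fires. A toy illustration of that bookkeeping with invented numbers (each GradientVisualize.__call__ appends a fresh 0.0 slot for the next logging step, which is why the lists end in 0.0):

    # Invented values: one gradient sum per logging step and model.
    logs = {'generator': [0.12, 0.08, 0.0],
            'discriminator': [0.31, 0.27, 0.0]}
    for key, val in logs.items():
        print('{} Mean Gradients : {}'.format(key, sum(val) / len(val)))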
