Use tensorboard_pytorch to get summaries.
ruotianluo committed Aug 28, 2017
1 parent a669ea5 commit 5a291e3
Showing 5 changed files with 63 additions and 125 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -45,7 +45,8 @@ Additional features not mentioned in the [report](https://arxiv.org/pdf/1702.021

### Prerequisites
- A basic pytorch installation. The code follows **0.2**. If you are using the old **0.1.12**, you can check out the 0.1.12 branch.
- Python packages you might not have: `cython`, `opencv-python`, `easydict` (similar to [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn)). For `easydict` make sure you have the right version. I use 1.6.
- Python packages you might not have: `cffi`, `opencv-python`, `easydict` (similar to [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn)). For `easydict` make sure you have the right version. I use 1.6.
- [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) to visualize the training and validation curves (a minimal usage sketch follows this list).
- ~~Docker users: Since the recent upgrade, the docker image on Docker Hub (https://hub.docker.com/r/mbuckler/tf-faster-rcnn-deps/) is no longer valid. However, you can still build your own image using the Dockerfile located in the `docker` folder (the cuda 8 version, as required by Tensorflow r1.0). Make sure to follow the Tensorflow installation instructions to install and use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). Last, after launching the container, you have to build the Cython modules within the running container.~~
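
Since tensorboard-pytorch is a new prerequisite, here is a minimal sketch of how its writer is typically used for those curves, limited to the `tb.writer.FileWriter` / `tb.summary.scalar` / `add_summary` calls this commit adds below; the log directory and loss values are made-up placeholders, not part of the repository.

```python
# Hypothetical smoke test for the tensorboard-pytorch prerequisite.
import tensorboard as tb

writer = tb.writer.FileWriter('/tmp/faster_rcnn_tb')        # placeholder log directory

for it, loss in enumerate([1.2, 0.9, 0.7]):                  # stand-in loss values
    writer.add_summary(tb.summary.scalar('total_loss', loss), float(it))
```

Pointing TensorBoard at the same directory (`tensorboard --logdir /tmp/faster_rcnn_tb`) should then show the curve.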

### Installation
28 changes: 14 additions & 14 deletions lib/model/train_val.py
@@ -7,7 +7,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorboard as tb

from model.config import cfg
import roi_data_layer.roidb as rdl_roidb
@@ -131,8 +131,8 @@ def construct_graph(self):
params += [{'params':[value],'lr':lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
# Write the train and validation information to tensorboard
self.writer = tf.summary.FileWriter(self.tbdir)
self.valwriter = tf.summary.FileWriter(self.tbvaldir)
self.writer = tb.writer.FileWriter(self.tbdir)
self.valwriter = tb.writer.FileWriter(self.tbvaldir)

return lr, self.optimizer

@@ -208,7 +208,7 @@ def remove_snapshot(self, np_paths, ss_paths):
os.remove(str(sfile))
ss_paths.remove(sfile)

def train_model(self, sess, max_iters):
def train_model(self, max_iters):
# Build data layers for both training and validation set
self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)
@@ -252,12 +252,12 @@ def train_model(self, sess, max_iters):
if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
# Compute the graph with summary
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
self.net.train_step_with_summary(sess, blobs, self.optimizer)
self.writer.add_summary(summary, float(iter))
self.net.train_step_with_summary(blobs, self.optimizer)
for _sum in summary: self.writer.add_summary(_sum, float(iter))
# Also check the summary on the validation set
blobs_val = self.data_layer_val.forward()
summary_val = self.net.get_summary(sess, blobs_val)
self.valwriter.add_summary(summary_val, float(iter))
summary_val = self.net.get_summary(blobs_val)
for _sum in summary_val: self.valwriter.add_summary(_sum, float(iter))
last_summary_time = now
else:
# Compute the graph without summary
@@ -340,10 +340,10 @@ def train_net(network, imdb, roidb, valroidb, output_dir, tb_dir,
"""Train a Faster R-CNN network."""
roidb = filter_roidb(roidb)
valroidb = filter_roidb(valroidb)
with tf.Session(config=tf.ConfigProto(device_count = {'GPU': 0})) as sess:
sw = SolverWrapper(network, imdb, roidb, valroidb, output_dir, tb_dir,
pretrained_model=pretrained_model)

print('Solving...')
sw.train_model(sess, max_iters)
print('done solving')
sw = SolverWrapper(network, imdb, roidb, valroidb, output_dir, tb_dir,
pretrained_model=pretrained_model)

print('Solving...')
sw.train_model(max_iters)
print('done solving')
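
For orientation, a self-contained sketch of the solver-side flow that `train_val.py` switches to: two FileWriters (training and validation) and, at each summary interval, a list of summary protobufs written one by one. Everything here is a hypothetical stand-in (directories, loop, loss values); only the `tb.*` calls mirror the diff above.

```python
# Hypothetical condensed version of the writer setup and per-iteration
# summary writing that replaces the old tf.Session-based path.
import tensorboard as tb

writer = tb.writer.FileWriter('/tmp/tb/train')      # stand-in for self.tbdir
valwriter = tb.writer.FileWriter('/tmp/tb/val')     # stand-in for self.tbvaldir

for it in range(3):                                 # stand-in training loop
    # train_step_with_summary / get_summary now return lists of summaries
    summary = [tb.summary.scalar('total_loss', 1.0 / (it + 1))]
    summary_val = [tb.summary.scalar('total_loss', 1.3 / (it + 1))]
    for _sum in summary:
        writer.add_summary(_sum, float(it))
    for _sum in summary_val:
        valwriter.add_summary(_sum, float(it))
```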
153 changes: 45 additions & 108 deletions lib/nets/network.py
@@ -29,6 +29,8 @@

from model.config import cfg

import tensorboard as tb

class Network(nn.Module):
def __init__(self, batch_size=1):
nn.Module.__init__(self)
@@ -41,39 +43,36 @@ def __init__(self, batch_size=1):
self._proposal_targets = {}
self._layers = {}
self._gt_image = None
self._act_summaries = []
self._act_summaries = {}
self._score_summaries = {}
self._train_summaries = {}
self._event_summaries = {}
self._image_gt_summaries = {}
self._variables_to_fix = {}

def _add_gt_image(self):
# add back mean
image = self._image_gt_summaries['image']['placeholder'] + cfg.PIXEL_MEANS
image = self._image_gt_summaries['image'] + cfg.PIXEL_MEANS
# BGR to RGB (opencv uses BGR)
self._gt_image = tf.reverse(image, axis=[-1])
self._gt_image = image[:,:,:,::-1].copy(order='C')

def _add_gt_image_summary(self):
# use a customized visualization function to visualize the boxes
if self._gt_image is None:
self._add_gt_image()
image = tf.py_func(draw_bounding_boxes,
[self._gt_image, self._image_gt_summaries['gt_boxes']['placeholder'], self._image_gt_summaries['im_info']['placeholder']],
tf.float32)
self._add_gt_image()
image = draw_bounding_boxes(\
self._gt_image, self._image_gt_summaries['gt_boxes'], self._image_gt_summaries['im_info'])

return tf.summary.image('GROUND_TRUTH', image)
return tb.summary.image('GROUND_TRUTH', torch.from_numpy(image[0].astype('float32')/ 255.0).permute(2,0,1))

def _add_act_summary(self, tensor):
tf.summary.histogram('ACT/' + tensor.op.name + '/activations', tensor)
tf.summary.scalar('ACT/' + tensor.op.name + '/zero_fraction',
tf.nn.zero_fraction(tensor))
def _add_act_summary(self, key, tensor):
return tb.summary.histogram('ACT/' + key + '/activations', tensor.data.cpu().numpy(), bins='auto'), \
tb.summary.scalar('ACT/' + key + '/zero_fraction',
(tensor.data == 0).float().sum() / tensor.numel())

def _add_score_summary(self, key, tensor):
tf.summary.histogram('SCORE/' + tensor.op.name + '/' + key + '/scores', tensor)
return tb.summary.histogram('SCORE/' + key + '/scores', tensor.data.cpu().numpy(), bins='auto')

def _add_train_summary(self, var):
tf.summary.histogram('TRAIN/' + var.op.name, var)
def _add_train_summary(self, key, var):
return tb.summary.histogram('TRAIN/' + key, var.data.cpu().numpy(), bins='auto')

def _proposal_top_layer(self, rpn_cls_prob, rpn_bbox_pred):
rois, rpn_scores = proposal_top_layer(\
@@ -152,7 +151,7 @@ def _anchor_target_layer(self, rpn_cls_score):
self._anchor_targets['rpn_bbox_outside_weights'] = rpn_bbox_outside_weights

for k in self._anchor_targets.keys():
self._score_summaries[k]['value'] = self._anchor_targets[k]
self._score_summaries[k] = self._anchor_targets[k]

return rpn_labels

@@ -168,7 +167,7 @@ def _proposal_target_layer(self, rois, roi_scores):
self._proposal_targets['bbox_outside_weights'] = bbox_outside_weights

for k in self._proposal_targets.keys():
self._score_summaries[k]['value'] = self._proposal_targets[k]
self._score_summaries[k] = self._proposal_targets[k]

return rois, roi_scores

@@ -235,13 +234,13 @@ def _add_losses(self, sigma_rpn=3.0):
self._losses['total_loss'] = loss

for k in self._losses.keys():
self._event_summaries[k]['value'] = self._losses[k]
self._event_summaries[k] = self._losses[k]

return loss

def _region_proposal(self, net_conv):
rpn = F.relu(self.rpn_net(net_conv))
self._act_summaries['rpn']['value'] = rpn
self._act_summaries['rpn'] = rpn

rpn_cls_score = self.rpn_cls_score_net(rpn) # batch * (num_anchors * 2) * h * w

@@ -315,97 +314,35 @@ def create_architecture(self, num_classes, tag=None,

# Initialize layers
self._init_modules()
self._init_summary_op()

def _init_summary_op(self):
"""
Handle summaries Notes:
Here we still use original tensorflow tensorboard to do summary.
The way we send our result to summary, is we create placeholders for the values that needs summarized, and
created summary operators of these placeholders
Then during forwarding, we save the values.
To send it to the tensorboard, we run the summary operator by feeding the placeholders with
the saved values and get the summary.

"""

# Here we first create placeholders, and create summary operation.

# Manually add losses to event_summaries
for key in ['cross_entropy','loss_box','rpn_cross_entropy','rpn_loss_box','total_loss']:
self._event_summaries[key] = {'placeholder': tf.placeholder(tf.float32, shape=(), name=key)}

# Manually add losses to score_summaries
score_summaries_keys = []
# _anchor_targets
score_summaries_keys += ['rpn_labels','rpn_bbox_targets', 'rpn_bbox_inside_weights', 'rpn_bbox_outside_weights']
#_proposal_targets
score_summaries_keys += ['rois', 'labels', 'bbox_targets', 'bbox_inside_weights', 'bbox_outside_weights']
#_predictions
score_summaries_keys += ["rpn_cls_score", "rpn_cls_score_reshape", "rpn_cls_prob", "rpn_cls_pred", "rpn_bbox_pred", \
"cls_score", "cls_pred", "cls_prob", "bbox_pred", "rois"]
for key in score_summaries_keys:
self._score_summaries[key] = {'placeholder': tf.placeholder(tf.float32, name=key)}

# Manually add act_summaries
self._act_summaries = {'conv':{'placeholder': tf.placeholder(tf.float32, name='conv')},
'rpn':{'placeholder': tf.placeholder(tf.float32, name='rpn')}}


self._image_gt_summaries = {'image':{'placeholder': tf.placeholder(tf.float32, shape=[self._batch_size, None, None, 3])},\
'gt_boxes':{'placeholder': tf.placeholder(tf.float32, shape=[None, 5])},
'im_info':{'placeholder': tf.placeholder(tf.float32, shape=[None, 3])}}

# Add train summaries
for k, var in dict(self.named_parameters()).items():
if var.requires_grad:
self._train_summaries[k] = {'placeholder': tf.placeholder(tf.float32, name=k)}

val_summaries = []
with tf.device("/cpu:0"):
val_summaries.append(self._add_gt_image_summary())
for key, var in self._event_summaries.items():
val_summaries.append(tf.summary.scalar(key, var['placeholder']))
for key, var in self._score_summaries.items():
self._add_score_summary(key, var['placeholder'])
for var in self._act_summaries.values():
self._add_act_summary(var['placeholder'])
for var in self._train_summaries.values():
self._add_train_summary(var['placeholder'])

self._summary_op = tf.summary.merge_all()
self._summary_op_val = tf.summary.merge(val_summaries)


def _run_summary_op(self, sess, val=False):
def _run_summary_op(self, val=False):
"""
Run the summary operator: feed the placeholders with corresponding network outputs (activations)
"""
def delete_summaries_values(d):
# Delete the saved values to save memory, in case we have references of these computational graphs.
for _ in d.values(): del _['value']

feed_dict = {}
summaries = []
# Add image gt
feed_dict.update({_['placeholder']:_['value'] for _ in self._image_gt_summaries.values()})
delete_summaries_values(self._image_gt_summaries)
summaries.append(self._add_gt_image_summary())
# Add event_summaries
feed_dict.update({_['placeholder']:_['value'].data[0] for _ in self._event_summaries.values()})
delete_summaries_values(self._event_summaries)
for key, var in self._event_summaries.items():
summaries.append(tb.summary.scalar(key, var.data[0]))
self._event_summaries = {}
if not val:
# Add score summaries
feed_dict.update({_['placeholder']:_['value'].data.cpu().numpy() for _ in self._score_summaries.values()})
delete_summaries_values(self._score_summaries)
for key, var in self._score_summaries.items():
summaries.append(self._add_score_summary(key, var))
self._score_summaries = {}
# Add act summaries
feed_dict.update({_['placeholder']:_['value'].data.cpu().numpy() for _ in self._act_summaries.values()})
delete_summaries_values(self._act_summaries)
for key, var in self._act_summaries.items():
summaries += self._add_act_summary(key, var)
self._act_summaries = {}
# Add train summaries
for k, var in dict(self.named_parameters()).items():
if var.requires_grad:
feed_dict.update({self._train_summaries[k]['placeholder']: var.data.cpu().numpy()})
return sess.run(self._summary_op, feed_dict=feed_dict)
else:
return sess.run(self._summary_op_val, feed_dict=feed_dict)
summaries.append(self._add_train_summary(k, var))

self._image_gt_summaries = {}

return summaries

def _predict(self, mode):
# This is just _build_network in tf-faster-rcnn
@@ -425,14 +362,14 @@ def _predict(self, mode):
cls_prob, bbox_pred = self._region_classification(fc7)

for k in self._predictions.keys():
self._score_summaries[k]['value'] = self._predictions[k]
self._score_summaries[k] = self._predictions[k]

return rois, cls_prob, bbox_pred

def forward(self, image, im_info, gt_boxes=None, mode='TRAIN'):
self._image_gt_summaries['image']['value'] = image
self._image_gt_summaries['gt_boxes']['value'] = gt_boxes
self._image_gt_summaries['im_info']['value'] = im_info
self._image_gt_summaries['image'] = image
self._image_gt_summaries['gt_boxes'] = gt_boxes
self._image_gt_summaries['im_info'] = im_info

self._image = Variable(torch.from_numpy(image.transpose([0,3,1,2])).cuda(), volatile=mode == 'TEST')
self._im_info = im_info # No need to change; actually it can be an list
@@ -489,11 +426,11 @@ def delete_intermediate_states(self):
for k in d.keys():
del d[k]

def get_summary(self, sess, blobs):
def get_summary(self, blobs):
self.eval()
self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])
self.train()
summary = self._run_summary_op(sess, True)
summary = self._run_summary_op(True)

return summary

@@ -514,7 +451,7 @@ def train_step(self, blobs, train_op):

return rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss

def train_step_with_summary(self, sess, blobs, train_op):
def train_step_with_summary(self, blobs, train_op):
self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].data[0], \
self._losses['rpn_loss_box'].data[0], \
@@ -524,7 +461,7 @@ def train_step_with_summary(self, sess, blobs, train_op):
train_op.zero_grad()
self._losses['total_loss'].backward()
train_op.step()
summary = self._run_summary_op(sess)
summary = self._run_summary_op()

self.delete_intermediate_states()

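
To make the new flow in `network.py` easier to follow, here is a small, self-contained sketch of the pattern the commit adopts: the forward pass stashes raw values (losses, scores, activations) in plain Python dicts, and `_run_summary_op` later turns them into tensorboard-pytorch summaries and clears the dicts so intermediate tensors can be freed. The class, shapes, and values below are invented for illustration; only the `tb.summary.*` / `FileWriter` calls mirror the ones used in the diff.

```python
# Toy stand-in for the Network class's summary bookkeeping (not the real class).
import numpy as np
import tensorboard as tb

class SummaryCollector(object):
    def __init__(self):
        self._event_summaries = {}   # scalar values, e.g. losses
        self._score_summaries = {}   # arrays to histogram, e.g. scores/targets
        self._act_summaries = {}     # activations

    def forward(self):
        # The real forward pass fills these dicts with live values; placeholders here.
        self._event_summaries['total_loss'] = 0.42
        self._score_summaries['cls_score'] = np.random.randn(16, 21)
        self._act_summaries['rpn'] = np.random.rand(1, 512, 14, 14)

    def run_summary_op(self):
        summaries = []
        for key, val in self._event_summaries.items():
            summaries.append(tb.summary.scalar(key, val))
        for key, val in self._score_summaries.items():
            summaries.append(tb.summary.histogram('SCORE/' + key + '/scores', val, bins='auto'))
        for key, val in self._act_summaries.items():
            summaries.append(tb.summary.histogram('ACT/' + key + '/activations', val, bins='auto'))
            # plain-numpy stand-in for tf.nn.zero_fraction, as in _add_act_summary
            summaries.append(tb.summary.scalar('ACT/' + key + '/zero_fraction',
                                               float((val == 0).sum()) / val.size))
        # Reset the dicts after materializing the summaries, as _run_summary_op does.
        self._event_summaries, self._score_summaries, self._act_summaries = {}, {}, {}
        return summaries

# Usage: collect values for one step, then write every summary to the event file.
collector = SummaryCollector()
collector.forward()
writer = tb.writer.FileWriter('/tmp/tb_example')    # placeholder log directory
for s in collector.run_summary_op():
    writer.add_summary(s, 0.0)
```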
2 changes: 1 addition & 1 deletion lib/nets/resnet_v1.py
@@ -233,7 +233,7 @@ def _crop_pool_layer(self, bottom, rois):

def _image_to_head(self):
net_conv = self._layers['head'](self._image)
self._act_summaries['conv']['value'] = net_conv
self._act_summaries['conv'] = net_conv

return net_conv

2 changes: 1 addition & 1 deletion lib/nets/vgg16.py
@@ -47,7 +47,7 @@ def _init_modules(self):

def _image_to_head(self):
net_conv = self._layers['head'](self._image)
self._act_summaries['conv']['value'] = net_conv
self._act_summaries['conv'] = net_conv

return net_conv

