support for conv3d #52

Closed
2 of 4 tasks
dearkafka opened this issue Jul 4, 2017 · 7 comments

@dearkafka

  • Check that you are up-to-date with the master branch of keras-vis. You can update with:
    pip install git+https://github.com/raghakot/keras-vis.git --upgrade --no-deps

  • If running on TensorFlow, check that you are up-to-date with the latest version. The installation instructions can be found here.

  • If running on Theano, check that you are up-to-date with the master branch of Theano. You can update with:
    pip install git+https://github.com/Theano/Theano.git --upgrade --no-deps

  • Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).

Hello, thanks for a great package. I found that the current version does not support 3D images (Conv3D), which is understandable, but it would be great if you could add this feature.

@raghakot
Owner

raghakot commented Jul 4, 2017

It does; in fact, N-dimensional inputs should work. Do you have a gist that shows otherwise?

@dearkafka
Author

dearkafka commented Jul 4, 2017

Yes. I'm trying the basic example:

```python
import matplotlib.pyplot as plt
import numpy as np
# Import paths follow the keras-vis module layout visible in the traceback.
from vis.utils import utils
from vis.visualization import visualize_activation, get_num_filters

# `model` is the Conv3D Keras model being inspected (definition not shown).

# The name of the layer we want to visualize
# (see model definition in vggnet.py)
layer_name = 'conv1'
layer_idx = [idx for idx, layer in enumerate(model.layers) if layer.name == layer_name][0]

# Visualize all filters in this layer.
filters = np.arange(get_num_filters(model.layers[layer_idx]))

# Generate an input image for each filter and overlay the filter index on it with utils.draw_text.
vis_images = []
for idx in filters:
    img = visualize_activation(model, layer_idx, filter_indices=idx)
    img = utils.draw_text(img, str(idx))
    vis_images.append(img)

# Generate stitched image palette with 8 cols.
stitched = utils.stitch_images(vis_images, cols=8)
plt.axis('off')
plt.imshow(stitched)
plt.title(layer_name)
plt.show()
```

and it fails with:

```
AssertionError                            Traceback (most recent call last)
<ipython-input-19-20fc5da0da03> in <module>()
     10 vis_images = []
     11 for idx in filters:
---> 12     img = visualize_activation(model, layer_idx, filter_indices=idx)
     13     img = utils.draw_text(img, str(idx))
     14     vis_images.append(img)

/usr/local/lib/python3.5/dist-packages/vis/visualization.py in visualize_activation(model, layer_idx, filter_indices, seed_img, text, act_max_weight, lp_norm_weight, tv_weight, **optimizer_params)
    108     ]
    109 
--> 110     opt = Optimizer(model.input, losses, norm_grads=False)
    111     img = opt.minimize(**optimizer_params)[0]
    112     if text:

/usr/local/lib/python3.5/dist-packages/vis/optimizer.py in __init__(self, img_input, losses, wrt, norm_grads)
     33             # Perf optimization. Don't build loss function with 0 weight.
     34             if weight != 0:
---> 35                 loss_fn = weight * loss.build_loss()
     36                 overall_loss = loss_fn if overall_loss is None else overall_loss + loss_fn
     37                 self.loss_names.append(loss.name)

/usr/local/lib/python3.5/dist-packages/vis/regularizers.py in build_loss(self)
     49         \left ( x(h+1, w, c) - x(h, w, c) \right )^{2} \right )^{\frac{\beta}{2}}$$
     50         """
---> 51         assert 4 == K.ndim(self.img)
     52         a = K.square(self.img[utils.slicer[:, :, 1:, :-1]] - self.img[utils.slicer[:, :, :-1, :-1]])
     53         b = K.square(self.img[utils.slicer[:, :, :-1, 1:]] - self.img[utils.slicer[:, :, :-1, :-1]])

AssertionError:
```

Correct me if I'm wrong, but the problem seems to be with the dimensions of the tensors.
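The failing check is `assert 4 == K.ndim(self.img)` in the total variation regularizer. That old code only accepted 4-D image batches, while a Conv3D model takes 5-D input (batch, three spatial axes, channels). The two lines after the assert square the difference to the next row and to the next column, and a natural N-D generalization adds one such term per spatial axis. The sketch below shows that idea in NumPy rather than with the Keras backend. It assumes a channels-last layout, is not keras-vis's own implementation, and `total_variation_nd` is a hypothetical name.

```python
import numpy as np


def total_variation_nd(img, beta=2.0):
    """Total variation of a channels-last batch with any number of spatial axes.

    img has shape (batch, d1, ..., dk, channels): k = 2 for ordinary images,
    k = 3 for Conv3D-style volumes.
    """
    img = np.asarray(img, dtype=np.float64)
    n_spatial = img.ndim - 2
    # Reference crop: drop the last element along every spatial axis.
    base = (slice(None),) + (slice(None, -1),) * n_spatial + (slice(None),)
    sq_sum = np.zeros(img[base].shape)
    for axis in range(1, img.ndim - 1):
        # Same crop, shifted forward by one along `axis` (the neighbouring pixel/voxel).
        shifted = list(base)
        shifted[axis] = slice(1, None)
        sq_sum += np.square(img[tuple(shifted)] - img[base])
    # Raise the summed squared neighbour differences to beta / 2, then sum everything.
    return float(np.sum(sq_sum ** (beta / 2.0)))


if __name__ == "__main__":
    volumes = np.random.rand(1, 8, 16, 16, 1)  # (batch, depth, rows, cols, channels)
    print(total_variation_nd(volumes))
```

For a 4-D batch this reduces to the same `a + b` sum the traceback computes. For 5-D input it simply adds a depth term.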

@raghakot
Owner

raghakot commented Jul 4, 2017

Ah, I see what's going on. The keras-vis release that pip installed for your Python 3 is out of date; that assert statement comes from very old code. Try installing from source instead.
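The traceback paths point at `/usr/local/lib/python3.5/dist-packages/vis/`, the copy pip installed for Python 3. Before and after reinstalling from source, a small standard-library check like the sketch below can show which interpreter is running and which copy of the `vis` package it imports. Running pip as `python3 -m pip` ensures pip targets that same interpreter. `locate_package` is a hypothetical helper, not part of keras-vis.

```python
import importlib.util
import sys


def locate_package(name):
    """Return the file a top-level package would be imported from, or None if it is not installed."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Which interpreter is running, and which copy of keras-vis (package `vis`) it sees.
    print("interpreter:", sys.executable)
    print("vis imported from:", locate_package("vis"))
```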

@dearkafka
Author

Thank you. However, running the same example (with visualize_activation changed to visualize_class_activation), I get:

```
TypeError                                 Traceback (most recent call last)
<ipython-input-9-39a9eb8ab422> in <module>()
     11 for idx in filters:
     12     img = visualize_class_activation(model, layer_idx, filter_indices=idx)
---> 13     img = utils.draw_text(img, str(idx))
     14     vis_images.append(img)
     15 

/usr/local/lib/python3.5/dist-packages/vis/utils/utils.py in draw_text(img, text, position, font, font_size, color)
    228 
    229     # Don't mutate original image
--> 230     img = Image.fromarray(img)
    231     draw = ImageDraw.Draw(img)
    232     draw.text(position, text, fill=color, font=font)

/usr/local/lib/python3.5/dist-packages/PIL/Image.py in fromarray(obj, mode)
   2292         except KeyError:
   2293             # print(typekey)
-> 2294             raise TypeError("Cannot handle this data type")
   2295     else:
   2296         rawmode = mode

TypeError: Cannot handle this data type
```

@raghakot
Owner

raghakot commented Jul 4, 2017

Try commenting out the utils.draw_text(img, str(idx)) line, which draws text on the image. Is the input a 2D image?

@dearkafka
Author

Thank you, you are right. I also changed the image stitching to handle 3D -> 2D, and it works great!
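The thread does not show the actual stitching change. As a hypothetical illustration only, one way to turn a 3D visualization into something 2D image tools accept is to tile its depth slices into a single uint8 montage. This is also relevant to the earlier TypeError: PIL's `Image.fromarray` raises "Cannot handle this data type" when an array's shape and dtype do not map to an image mode, as is the case for a 4-D volume, for instance. The sketch assumes each visualization is a NumPy array laid out as (depth, rows, cols, channels); `volume_to_montage` and `n_cols` are made-up names.

```python
import numpy as np


def volume_to_montage(volume, n_cols=4):
    """Tile the depth slices of a (depth, rows, cols, channels) volume into one 2-D uint8 image."""
    vol = np.asarray(volume, dtype=np.float64)
    depth, h, w, ch = vol.shape
    # Rescale to 0..255 and convert to uint8, a dtype 2-D image tools generally accept.
    lo, hi = vol.min(), vol.max()
    scale = 255.0 / (hi - lo) if hi > lo else 0.0
    vol = np.clip(np.round((vol - lo) * scale), 0, 255).astype(np.uint8)
    n_rows = (depth + n_cols - 1) // n_cols
    # Unused grid cells (when depth is not a multiple of n_cols) stay black.
    montage = np.zeros((n_rows * h, n_cols * w, ch), dtype=np.uint8)
    for i in range(depth):
        row, col = divmod(i, n_cols)
        montage[row * h:(row + 1) * h, col * w:(col + 1) * w] = vol[i]
    # Drop a singleton channel axis so grayscale volumes become plain (rows, cols) images.
    return montage[..., 0] if ch == 1 else montage
```

The result is an ordinary (rows, cols) or (rows, cols, 3) uint8 array, so it should be usable with helpers that only handle 2D images.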

@raghakot
Owner

raghakot commented Aug 1, 2017

Can you PR the new and improved stitched images :D?
