This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit: Increment version to 0.1.0
ludwigschubert committed May 15, 2018
1 parent eee40f3 commit e830c8f
Showing 8 changed files with 119 additions and 12 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -36,7 +36,7 @@ run in your browser.
<img src="https://storage.googleapis.com/lucid-static/common/stickers/colab-tutorial.png" width="500" alt=""></img>
</a>

## Building Blocks
*Notebooks corresponding to the [Building Blocks of Interpretability](https://distill.pub/2018/building-blocks/) article*


@@ -80,6 +80,22 @@ This project is research code. It is not an official Google product.

## Development

### Style guide deviations

We use naming conventions to help differentiate tensors, operations, and values:

* Suffix variable names representing **tensors** with `_t`
* Suffix variable names representing **operations** with `_op`
* Don't suffix variable names representing concrete values

Usage example:

```python
global_step_t = tf.train.get_or_create_global_step()
global_step_init_op = tf.variables_initializer([global_step_t])
global_step = global_step_t.eval()
```

### Running Tests

Use `tox` to run the test suite on all supported environments.
1 change: 1 addition & 0 deletions lucid/misc/io/serialize_array.py
@@ -48,6 +48,7 @@ def _normalize_array(array, domain=(0, 1)):
array = np.squeeze(array)
assert len(array.shape) <= 3
assert np.issubdtype(array.dtype, np.number)
assert not np.isnan(array).any()

low, high = np.min(array), np.max(array)
if domain is None:
23 changes: 20 additions & 3 deletions lucid/misc/redirected_relu_grad.py
@@ -23,13 +23,20 @@
These functions provide a more convenient solution: temporarily override the
gradient of ReLUs to allow gradient to flow back through the ReLU -- even if it
didn't activate and had a derivative of zero -- allowing the visualization
process to get started. These functions override the gradient for at most 16
steps. Thus, you need to initialize `global_step` before using these functions.
Usage:
```python
from lucid.misc.gradient_override import gradient_override_map
from lucid.misc.redirected_relu_grad import redirected_relu_grad
...
global_step_t = tf.train.get_or_create_global_step()
init_global_step_op = tf.variables_initializer([global_step_t])
init_global_step_op.run()
...
with gradient_override_map({'Relu': redirected_relu_grad}):
  model.import_graph(...)
```
@@ -99,7 +106,12 @@ def redirected_relu_grad(op, grad):
batch = tf.shape(relu_grad)[0]
reshaped_relu_grad = tf.reshape(relu_grad, [batch, -1])
relu_grad_mag = tf.norm(reshaped_relu_grad, axis=1)
-return tf.where(relu_grad_mag > 0., relu_grad, redirected_grad)
result_grad = tf.where(relu_grad_mag > 0., relu_grad, redirected_grad)

global_step_t = tf.train.get_or_create_global_step()
return_relu_grad = tf.greater(global_step_t, tf.constant(16, tf.int64))

return tf.where(return_relu_grad, relu_grad, result_grad)


def redirected_relu6_grad(op, grad):
@@ -125,4 +137,9 @@ def redirected_relu6_grad(op, grad):
batch = tf.shape(relu_grad)[0]
reshaped_relu_grad = tf.reshape(relu_grad, [batch, -1])
relu_grad_mag = tf.norm(reshaped_relu_grad, axis=1)
-return tf.where(relu_grad_mag > 0., relu_grad, redirected_grad)
result_grad = tf.where(relu_grad_mag > 0., relu_grad, redirected_grad)

global_step_t = tf.train.get_or_create_global_step()
return_relu_grad = tf.greater(global_step_t, tf.constant(16, tf.int64))

return tf.where(return_relu_grad, relu_grad, result_grad)
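To see the shutoff behavior above in isolation, here is a small self-contained sketch. It is not part of this commit; TF 1.x graph mode is assumed, and the constants merely stand in for the real ReLU and redirected gradients.

```python
import tensorflow as tf

# Illustrative sketch of the gating pattern above -- not library code.
# The redirected gradient is used only while global_step <= 16.
with tf.Graph().as_default(), tf.Session() as sess:
  global_step_t = tf.train.get_or_create_global_step()
  sess.run(tf.variables_initializer([global_step_t]))

  relu_grad = tf.constant([1.0])        # stand-in for the true ReLU gradient
  redirected_grad = tf.constant([2.0])  # stand-in for the redirected gradient

  past_warmup = tf.greater(global_step_t, tf.constant(16, tf.int64))
  grad = tf.where(past_warmup, relu_grad, redirected_grad)

  print(sess.run(grad))                  # [2.] -- override still active
  sess.run(tf.assign(global_step_t, 17))
  print(sess.run(grad))                  # [1.] -- override shut off
```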
18 changes: 17 additions & 1 deletion lucid/optvis/objectives.py
@@ -131,7 +131,23 @@ def wrap_objective(f, *args, **kwds):

@wrap_objective
def neuron(layer_name, channel_n, x=None, y=None, batch=None):
"""Visualize a single neuron of a single channel."""
"""Visualize a single neuron of a single channel.
Defaults to the center neuron. When width and height are even numbers, we
choose the neuron in the bottom right of the center 2x2 neurons.
Odd width & height:            Even width & height:
+---+---+---+                  +---+---+---+---+
|   |   |   |                  |   |   |   |   |
+---+---+---+                  +---+---+---+---+
|   | X |   |                  |   |   |   |   |
+---+---+---+                  +---+---+---+---+
|   |   |   |                  |   |   | X |   |
+---+---+---+                  +---+---+---+---+
                               |   |   |   |   |
                               +---+---+---+---+
"""
def inner(T):
layer = T(layer_name)
shape = tf.shape(layer)
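As a companion to the docstring above (not part of this commit), the selected position can be reproduced with integer floor division; `H` and `W` here are hypothetical layer height and width:

```python
# Hypothetical sketch of the center-neuron selection illustrated above;
# floor division picks the middle cell for odd sizes and the bottom-right
# cell of the central 2x2 block for even sizes.
def center_position(H, W):
  return H // 2, W // 2

assert center_position(3, 3) == (1, 1)  # odd grid: the single middle neuron
assert center_position(4, 4) == (2, 2)  # even grid: bottom right of center 2x2
```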
32 changes: 26 additions & 6 deletions lucid/optvis/render.py
@@ -31,6 +31,8 @@

from lucid.optvis import objectives, param, transform
from lucid.misc.io import show
from lucid.misc.redirected_relu_grad import redirected_relu_grad, redirected_relu6_grad
from lucid.misc.gradient_override import gradient_override_map

# pylint: disable=invalid-name

@@ -40,8 +42,8 @@


def render_vis(model, objective_f, param_f=None, optimizer=None,
-transforms=None, thresholds=(512,),
-print_objectives=None, verbose=True,):
transforms=None, thresholds=(512,), print_objectives=None,
verbose=True, relu_gradient_override=True, use_fixed_seed=False):
"""Flexible optimization-base feature vis.
There's a lot of ways one might wish to customize otpimization-based
@@ -72,6 +74,11 @@ def render_vis(model, objective_f, param_f=None, optimizer=None,
whose values get logged during the optimization.
verbose: Should we display the visualization when we hit a threshold?
This should only be used in IPython.
relu_gradient_override: Whether to use the gradient override scheme
described in lucid/misc/redirected_relu_grad.py. On by default!
use_fixed_seed: Seed the RNG with a fixed value so results are reproducible.
Off by default. As of tf 1.8 this does not work as intended, see:
https://github.com/tensorflow/tensorflow/issues/9171
Returns:
2D array of optimization results containing evaluations of the supplied
param_f snapshotted at specified thresholds. Usually that will mean one or
@@ -80,7 +87,11 @@ def render_vis(model, objective_f, param_f=None, optimizer=None,

with tf.Graph().as_default() as graph, tf.Session() as sess:

-T = make_vis_T(model, objective_f, param_f, optimizer, transforms)
if use_fixed_seed: # does not mean results are reproducible, see Args doc
tf.set_random_seed(0)

T = make_vis_T(model, objective_f, param_f, optimizer, transforms,
relu_gradient_override)
print_objective_func = make_print_objective_func(print_objectives, T)
loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")
tf.global_variables_initializer().run()
@@ -105,7 +116,7 @@ def render_vis(model, objective_f, param_f=None, optimizer=None,


def make_vis_T(model, objective_f, param_f=None, optimizer=None,
-transforms=None):
transforms=None, relu_gradient_override=False):
"""Even more flexible optimization-base feature vis.
This function is the inner core of render_vis(), and can be used
@@ -155,10 +166,19 @@ def make_vis_T(model, objective_f, param_f=None, optimizer=None,
transform_f = make_transform_f(transforms)
optimizer = make_optimizer(optimizer, [])

-T = import_model(model, transform_f(t_image), t_image)
global_step = tf.train.get_or_create_global_step()
init_global_step = tf.variables_initializer([global_step])
init_global_step.run()

if relu_gradient_override:
with gradient_override_map({'Relu': redirected_relu_grad,
'Relu6': redirected_relu6_grad}):
T = import_model(model, transform_f(t_image), t_image)
else:
T = import_model(model, transform_f(t_image), t_image)
loss = objective_f(T)

-global_step = tf.Variable(0, trainable=False, name="global_step")

vis_op = optimizer.minimize(-loss, global_step=global_step)

local_vars = locals()
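For context, a usage sketch of the two new `render_vis` arguments. This is not part of the commit; the modelzoo import and layer name are assumptions about the surrounding library, and any loaded lucid model would do.

```python
from lucid.modelzoo.vision_models import InceptionV1
from lucid.optvis import objectives, param, render

# Assumption: a modelzoo model is available; any lucid Model works here.
model = InceptionV1()
model.load_graphdef()

objective_f = objectives.neuron("mixed4a", 42)
param_f = lambda: param.image(128)

# relu_gradient_override=True (the default) redirects ReLU gradients for the
# first 16 optimization steps; use_fixed_seed seeds TF's RNG, though as noted
# in the docstring this is not fully reproducible as of TF 1.8.
images = render.render_vis(model, objective_f, param_f,
                           thresholds=(0, 64),
                           relu_gradient_override=True,
                           use_fixed_seed=True)
```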
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@

from setuptools import setup, find_packages

-version = '0.0.8'
version = '0.1.0'

test_deps = [
'future',
32 changes: 32 additions & 0 deletions tests/misc/test_gradient_override.py
@@ -29,6 +29,10 @@ def gradient_override(op, grad):
return tf.constant(42)

with tf.Session().as_default() as sess:
global_step = tf.train.get_or_create_global_step()
init_global_step = tf.variables_initializer([global_step])
init_global_step.run()

a = tf.constant(1.)
standard_relu = tf.nn.relu(a)
grad_wrt_a = tf.gradients(standard_relu, a, [1.])[0]
@@ -56,10 +60,38 @@ def test_gradient_override_relu6_directionality(nonl_name, nonl,
nonl_grad_override, examples):
for incoming_grad, input, grad in examples:
with tf.Session().as_default() as sess:
global_step = tf.train.get_or_create_global_step()
init_global_step = tf.variables_initializer([global_step])
init_global_step.run()

batched_shape = [1,1]
incoming_grad_t = tf.constant(incoming_grad, shape=batched_shape)
input_t = tf.constant(input, shape=batched_shape)
with gradient_override_map({nonl_name: nonl_grad_override}):
nonl_t = nonl(input_t)
grad_wrt_input = tf.gradients(nonl_t, input_t, [incoming_grad_t])[0]
assert (grad_wrt_input.eval() == grad).all()

@pytest.mark.parametrize("nonl_name,nonl,nonl_grad_override, examples", nonls)
def test_gradient_override_shutoff(nonl_name, nonl,
nonl_grad_override, examples):
for incoming_grad, input, grad in examples:
with tf.Session().as_default() as sess:
global_step_t = tf.train.get_or_create_global_step()
global_step_init_op = tf.variables_initializer([global_step_t])
global_step_init_op.run()
global_step_assign_t = tf.assign(global_step_t, 17)
sess.run(global_step_assign_t)

# similar setup to test_gradient_override_relu6_directionality, but here we
# check that the override is *not* applied: after 16 steps it is shut off,
# so the gradient matches the plain, non-overridden gradient
batched_shape = [1,1]
incoming_grad_t = tf.constant(incoming_grad, shape=batched_shape)
input_t = tf.constant(input, shape=batched_shape)
with gradient_override_map({nonl_name: nonl_grad_override}):
nonl_t = nonl(input_t)
grad_wrt_input = tf.gradients(nonl_t, input_t, [incoming_grad_t])[0]
nonl_t_no_override = nonl(input_t)
grad_wrt_input_no_override = tf.gradients(nonl_t_no_override, input_t, [incoming_grad_t])[0]
assert (grad_wrt_input.eval() == grad_wrt_input_no_override.eval()).all()
5 changes: 5 additions & 0 deletions tests/optvis/test_integration.py
@@ -20,4 +20,9 @@ def test_integration(decorrelate, fft):
verbose=False, transforms=[])
start_image = rendering[0]
end_image = rendering[-1]
objective_f = objectives.neuron("mixed3a", 177)
param_f = lambda: param.image(64, decorrelate=decorrelate, fft=fft)
rendering = render.render_vis(model, objective_f, param_f, verbose=False, thresholds=(0,64), use_fixed_seed=True)
start_image, end_image = rendering

assert (start_image != end_image).any()
