Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit

Permalink
Add ParameterEditor to scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
colah committed Sep 22, 2019
1 parent d812e9c commit 050217b
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions lucid/scratch/parameter_editor.py
@@ -0,0 +1,40 @@
import numpy as np


class ParameterEditor():
"""Conveniently edit the parameters of a lucid model.
Example usage:
model = models.InceptionV1()
param = ParameterEditor(model.graph_def)
# Flip weights of first channel of conv2d0
param["conv2d0_w", :, :, :, 0] *= -1
"""

def __init__(self, graph_def):
self.nodes = {}
for node in graph_def.node:
if "value" in node.attr:
self.nodes[str(node.name)] = node

def __getitem__(self, key):
name = key[0] if isinstance(key, tuple) else key
tensor = self.nodes[name].attr["value"].tensor
shape = [int(d.size) for d in tensor.tensor_shape.dim]
array = np.frombuffer(tensor.tensor_content, dtype="float32").reshape(shape).copy()
return array[key[1:]] if isinstance(key, tuple) else array

def __setitem__(self, key, new_value):
name = key[0] if isinstance(key, tuple) else key
tensor = self.nodes[name].attr["value"].tensor
node_shape = [int(d.size) for d in tensor.tensor_shape.dim]
if isinstance(key, tuple):
array = np.frombuffer(tensor.tensor_content, dtype="float32")
array = array.reshape(node_shape).copy()
array[key[1:]] = new_value
tensor.tensor_content = array.tostring()
else:
assert new_value.shape == node_shape
tensor.tensor_content = new_value.tostring()

0 comments on commit 050217b

Please sign in to comment.