Skip to content

Commit

Permalink
Merge pull request #156 from atticusg/main
Browse files Browse the repository at this point in the history
Update visualization code for causal model and rotation matrices
  • Loading branch information
frankaging committed May 6, 2024
2 parents 1dc9243 + 49838a8 commit d29f959
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
28 changes: 28 additions & 0 deletions pyvene/analyses/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import seaborn
import torch

def rotation_token_heatmap(rotate_layer,
tokens,
token_size,
variables,
intervention_size):

W = rotate_layer.weight.data
in_dim, out_dim = W.shape

assert in_dim % token_size == 0
assert in_dim / token_size >= len(tokens)

assert out_dim % intervention_size == 0
assert out_dim / intervention_size >= len(variables)

heatmap = []
for j in range(len(variables)):
row = []
for i in range(len(tokens)):
row.append(torch.norm(W[i*token_size:(i+1)*token_size, j*intervention_size:(j+1)*intervention_size]))
mean = sum(row)
heatmap.append([x/mean for x in row])
return seaborn.heatmap(heatmap,
xticklabels=tokens,
yticklabels=variables)
21 changes: 10 additions & 11 deletions pyvene/data_generators/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ def __init__(
assert variable in self.values
assert variable in self.children
assert variable in self.functions
assert len(inspect.getfullargspec(self.functions[variable])[0]) == len(
self.parents[variable]
)
if timesteps is not None:
assert variable in timesteps
for variable2 in copy.copy(self.variables):
Expand Down Expand Up @@ -79,6 +76,8 @@ def __init__(
self.equiv_classes = equiv_classes
else:
self.equiv_classes = {}

def generate_equiv_classes(self):
for var in self.variables:
if var in self.inputs or var in self.equiv_classes:
continue
Expand Down Expand Up @@ -113,7 +112,7 @@ def generate_timesteps(self):
def marginalize(self, target):
pass

def print_structure(self, pos=None):
def print_structure(self, pos=None, font=12, node_size=1000):
G = nx.DiGraph()
G.add_edges_from(
[
Expand All @@ -123,7 +122,7 @@ def print_structure(self, pos=None):
]
)
plt.figure(figsize=(10, 10))
nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos)
nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos, font_size=font, node_size=node_size)
plt.show()

def find_live_paths(self, intervention):
Expand All @@ -149,12 +148,9 @@ def find_live_paths(self, intervention):
del paths[1]
return paths

def print_setting(self, total_setting, display=None):
labeler = lambda var: var + ": " + str(total_setting[var]) \
if display is None or display[var] \
else var
def print_setting(self, total_setting, font=12, node_size=1000):
relabeler = {
var: labeler(var) for var in self.variables
var: var + ": " + str(total_setting[var]) for var in self.variables
}
G = nx.DiGraph()
G.add_edges_from(
Expand All @@ -170,7 +166,7 @@ def print_setting(self, total_setting, display=None):
if self.pos is not None:
for var in self.pos:
newpos[relabeler[var]] = self.pos[var]
nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos)
nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos, font_size=font, node_size=node_size)
plt.show()

def run_forward(self, intervention=None):
Expand Down Expand Up @@ -233,11 +229,14 @@ def sample_input(self, mandatory=None):

def sample_input_tree_balanced(self, output_var=None, output_var_value=None):
assert output_var is not None or len(self.outputs) == 1
self.generate_equiv_classes()

if output_var is None:
output_var = self.outputs[0]
if output_var_value is None:
output_var_value = random.choice(self.values[output_var])


def create_input(var, value, input={}):
parent_values = random.choice(self.equiv_classes[var][value])
for parent in parent_values:
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/CausalModelTestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def setUpClass(self):
self.parents,
self.functions
)
self.causal_model.generate_equiv_classes()

def test_initialization(self):
inputs = ['A', 'B']
Expand Down

0 comments on commit d29f959

Please sign in to comment.