Skip to content

Commit

Permalink
Merge pull request #16 from tfjgeorge/pvector_dot
Browse files Browse the repository at this point in the history
adds dot product to PVector
  • Loading branch information
tfjgeorge committed Mar 28, 2021
2 parents 911e55f + 5993154 commit ba043fc
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
19 changes: 19 additions & 0 deletions nngeometry/object/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@ def __sub__(self, other):
vector_repr=(self.get_flat_representation() -
other.get_flat_representation()))

def dot(self, other):
"""
Computes the dot product between `self` and `other`
:param other: The other `PVector`
"""
if self.vector_repr is not None or other.vector_repr is not None:
return torch.dot(self.get_flat_representation(),
other.get_flat_representation())
else:
dot_ = 0
for l_id, l in self.layer_collection.layers.items():
if l.bias is not None:
dot_ += torch.dot(self.dict_repr[l_id][1],
other.dict_repr[l_id][1])
dot_ += torch.dot(self.dict_repr[l_id][0].view(-1),
other.dict_repr[l_id][0].view(-1))
return dot_


class FVector:
"""
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch==1.8.0
torchvision==0.9.0
requests==2.24.0
torch==1.8.1
torchvision==0.9.1
requests==2.24.0
32 changes: 31 additions & 1 deletion tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,34 @@ def test_from_to_model():
# now model2 and model3 should be the same

for p2, p3 in zip(model2.parameters(), model3.parameters()):
check_tensors(p2, p3)
check_tensors(p2, p3)


def test_dot():
model = ConvNet()
layer_collection = LayerCollection.from_model(model)
r1 = random_pvector(layer_collection)
r2 = random_pvector(layer_collection)
dotr1r2 = r1.dot(r2)
check_ratio(torch.dot(r1.get_flat_representation(),
r2.get_flat_representation()),
dotr1r2)

r1 = random_pvector_dict(layer_collection)
r2 = random_pvector_dict(layer_collection)
dotr1r2 = r1.dot(r2)
check_ratio(torch.dot(r1.get_flat_representation(),
r2.get_flat_representation()),
dotr1r2)


r1 = random_pvector(layer_collection)
r2 = random_pvector_dict(layer_collection)
dotr1r2 = r1.dot(r2)
dotr2r1 = r2.dot(r1)
check_ratio(torch.dot(r1.get_flat_representation(),
r2.get_flat_representation()),
dotr1r2)
check_ratio(torch.dot(r1.get_flat_representation(),
r2.get_flat_representation()),
dotr2r1)

0 comments on commit ba043fc

Please sign in to comment.