Skip to content

Commit

Permalink
Remove kronecker product function from tests
Browse files Browse the repository at this point in the history
  • Loading branch information
elaubsch committed Mar 23, 2023
1 parent 368d8be commit 4f11cf1
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions deepcell_spots/decoding_functions_test.py
Expand Up @@ -33,7 +33,7 @@
from tensorflow.python.platform import test

from deepcell_spots.decoding_functions import (reshape_torch_array, decoding_function,
normalize_spot_values, kronecker_product,
normalize_spot_values,
chol_sigma_from_vec, mat_sqrt, rb_e_step,
gaussian_e_step, instantiate_rb_params,
instantiate_gaussian_params)
Expand Down Expand Up @@ -61,22 +61,6 @@ def test_normalize_spot_values(self):
self.assertEqual(data.shape, norm_data.shape)
pyro.get_param_store().clear()

def test_kronecker_product(self):
dim = 3
a = torch.zeros(dim, dim)
b = torch.zeros(dim, dim)
product_array = kronecker_product(a, b)
self.assertEqual(product_array.shape, (dim**2, dim**2))

a = torch.tensor([[1,2], [3,4]])
b = torch.tensor([[0,5], [6,7]])
product_array = kronecker_product(a, b)
expected_product = torch.tensor([[0,5,0,10],
[6,7,12,14],
[0,15,0,20],
[18,21,24,28]])
self.assertAllEqual(product_array, expected_product)

def test_chol_sigma_from_vec(self):
dim = 3
sigma_vec = torch.zeros(np.sum(np.arange(dim)+1))
Expand Down Expand Up @@ -206,7 +190,7 @@ def test_gaussian_e_step(self):
sigma_c = chol_sigma_from_vec(sigma_c_v, c)
sigma_r_v = torch.eye(r*c)[np.tril_indices(r, 0)]
sigma_r = chol_sigma_from_vec(sigma_r_v, r)
sigma = kronecker_product(sigma_r, sigma_c)
sigma = torch.kron(sigma_r, sigma_c)
class_prob_norm = gaussian_e_step(data, w, theta, sigma, k)
self.assertAllEqual(class_prob_norm.shape, torch.Size([n, k]))
self.assertAllInRange(class_prob_norm, 0, 1)
Expand Down

0 comments on commit 4f11cf1

Please sign in to comment.