Skip to content

Commit

Permalink
Add test cases multibatch and divisible by batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
elaubsch committed Apr 19, 2023
1 parent 3b99b2b commit a959188
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions deepcell_spots/decoding_functions_test.py
Expand Up @@ -129,10 +129,21 @@ def test_instantiate_gaussian_params(self):
pyro.get_param_store().clear()

def test_rb_e_step(self):
# number of barcodes = 2, rounds = 2, channels = 3, spots = 100
# number of barcodes = 2, rounds = 2, channels = 3, spots = 99999
r = 2
c = 3
n = 100
n = 99999
spots = np.random.rand(n, r*c)
data = torch.tensor(spots)
codes = torch.tensor([[1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]])
k = codes.shape[0]
w = torch.ones(k) / k
pyro.get_param_store().clear()

# number of barcodes = 2, rounds = 2, channels = 3, spots = 100000
r = 2
c = 3
n = 100000 # divisible by batch size
spots = np.random.rand(n, r*c)
data = torch.tensor(spots)
codes = torch.tensor([[1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]])
Expand Down Expand Up @@ -176,10 +187,30 @@ def test_rb_e_step(self):
pyro.get_param_store().clear()

def test_gaussian_e_step(self):
# number of barcodes = 2, rounds = 2, channels = 3, spots = 100
# number of barcodes = 2, rounds = 2, channels = 3, spots = 99999
r = 2
c = 3
n = 100
n = 99999
spots = np.random.rand(n, r*c)
data = torch.tensor(spots)
codes = torch.tensor([[1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]])
k = codes.shape[0]
w = torch.ones(k) / k
theta = torch.zeros(k, r*c)
sigma_c_v = torch.eye(c)[np.tril_indices(c, 0)]
sigma_c = chol_sigma_from_vec(sigma_c_v, c)
sigma_r_v = torch.eye(r*c)[np.tril_indices(r, 0)]
sigma_r = chol_sigma_from_vec(sigma_r_v, r)
sigma = torch.kron(sigma_r, sigma_c)
class_prob_norm = gaussian_e_step(data, w, theta, sigma, k)
self.assertAllEqual(class_prob_norm.shape, torch.Size([n, k]))
self.assertAllInRange(class_prob_norm, 0, 1)
pyro.get_param_store().clear()

# number of barcodes = 2, rounds = 2, channels = 3, spots = 100000
r = 2
c = 3
n = 100000 # divisible by batch size
spots = np.random.rand(n, r*c)
data = torch.tensor(spots)
codes = torch.tensor([[1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]])
Expand Down

0 comments on commit a959188

Please sign in to comment.