Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Apr 19, 2023
1 parent 4789eb2 commit 3d7719c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
19 changes: 19 additions & 0 deletions test/models/test_gpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import sys

import numpy as np
from numpy.testing import assert_equal, assert_

sys.path.append("../../../")

from atomai.models import Reconstructor


def test_model_reconstruct():
X = np.abs(np.random.randn(10, 10)) + 1
X[2] = 0
t = Reconstructor()
assert_equal(len(t.train_loss), 0)
y_pred = t.reconstruct(X, 2)
assert_equal(len(t.train_loss), 2)
assert_(isinstance(y_pred, np.ndarray))
assert_(not np.array_equal(X, y_pred))
65 changes: 54 additions & 11 deletions test/trainers/test_gptrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

sys.path.append("../../../")

from atomai.trainers import dklGPTrainer
from atomai.trainers import dklGPTrainer, GPTrainer


def weights_equal(m1, m2):
Expand All @@ -20,10 +20,52 @@ def weights_equal(m1, m2):
return all(eq_w)




def test_gptrainer_compiler():
X = np.random.randn(10, 2)
y = np.random.randn(10)
t = GPTrainer(precision="single")
t.compile_trainer(X, y, 1)
assert_(t.gp_model is not None)
assert_(t.likelihood is not None)


def test_gptrainer_train():
h = w = 10
X = np.array([(i, j) for i in range(h) for j in range(w)])
y = np.random.randn(len(X))
t = GPTrainer()
t.compile_trainer(X, y)
params_init = dc(t.gp_model.base_covar_module.base_kernel.lengthscale.detach().cpu().numpy())
t.train_step()
params_final = t.gp_model.base_covar_module.base_kernel.lengthscale.detach().cpu().numpy()
assert_(not np.array_equal(params_init, params_final))


def test_gptrainer_compile_and_run():
h = w = 10
X = np.array([(i, j) for i in range(h) for j in range(w)])
y = np.random.randn(len(X))
t = GPTrainer()
t.compile_trainer(X, y, training_cycles=3)
_ = t.run()
assert_equal(len(t.train_loss), 3)


def test_gptrainer_run():
h = w = 10
X = np.array([(i, j) for i in range(h) for j in range(w)])
y = np.random.randn(len(X))
t = GPTrainer()
_ = t.run(X, y, 3)
assert_equal(len(t.train_loss), 3)


@pytest.mark.parametrize(
"precision, dtype",
[("single", torch.float32), ("double", torch.float64)])
def test_trainer_precision(precision, dtype):
def test_dkltrainer_precision(precision, dtype):
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -33,7 +75,7 @@ def test_trainer_precision(precision, dtype):
assert_equal(y_.dtype, dtype)


def test_trainer_compiler():
def test_dkltrainer_compiler():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -43,7 +85,7 @@ def test_trainer_compiler():
assert_(t.likelihood is not None)


def test_multi_model_trainer_compiler():
def test_multi_model_dkltrainer_compiler():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(2, 50)
Expand All @@ -53,7 +95,7 @@ def test_multi_model_trainer_compiler():
assert_(t.likelihood is not None)


def test_trainer_train():
def test_dkltrainer_train():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -65,7 +107,7 @@ def test_trainer_train():
assert_(not weights_equal(w_init, w_final))


def test_multi_model_trainer_train():
def test_multi_model_dkltrainer_train():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(2, 50)
Expand All @@ -81,7 +123,7 @@ def test_multi_model_trainer_train():
assert_(not weights_equal(w_init2, w_final2))


def test_trainer_train_freeze_w():
def test_dkltrainer_train_freeze_w():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -93,7 +135,7 @@ def test_trainer_train_freeze_w():
assert_(weights_equal(w_init, w_final))


def test_trainer_compile_and_run():
def test_dkltrainer_compile_and_run():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -103,7 +145,7 @@ def test_trainer_compile_and_run():
assert_equal(len(t.train_loss), 3)


def test_ensemble_trainer_compile_and_run():
def test_ensemble_dkltrainer_compile_and_run():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(1, 50).repeat(3, axis=0)
Expand All @@ -115,7 +157,7 @@ def test_ensemble_trainer_compile_and_run():
assert_(not weights_equal(w1, w2))


def test_trainer_run():
def test_dkltrainer_run():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -124,7 +166,7 @@ def test_trainer_run():
assert_equal(len(t.train_loss), 3)


def test_trainer_save_weights():
def test_dkltrainer_save_weights():
indim = 32
X = np.random.randn(50, indim)
y = np.random.randn(50)
Expand All @@ -136,3 +178,4 @@ def test_trainer_save_weights():
t.gp_model.feature_extractor.load_state_dict(loaded_weights)
w2 = t.gp_model.feature_extractor.state_dict()
assert_(weights_equal(w1, w2))

0 comments on commit 3d7719c

Please sign in to comment.