diff --git a/test/acquisition/test_knowledge_gradient.py b/test/acquisition/test_knowledge_gradient.py index 4510356cf6..d503d8b9d7 100644 --- a/test/acquisition/test_knowledge_gradient.py +++ b/test/acquisition/test_knowledge_gradient.py @@ -445,9 +445,10 @@ def test_evaluate_qMFKG(self): raw_samples=1, ) patch_f.asset_called_once() + cargs, ckwargs = patch_f.call_args self.assertTrue( ( - patch_f.call_args.kwargs["X"] + ckwargs["X"] == torch.ones(2, 1, 1, device=self.device, dtype=dtype) ).all() )