diff --git a/tests/test_basic.py b/tests/test_basic.py index ac54b29..1cc7eaf 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -16,23 +16,28 @@ def test_auto_martini_imported(): assert "auto_martini" in sys.modules -@pytest.mark.parametrize("sdf_file", dpath.glob("*.sdf")) -def test_auto_martini_run_sdf(sdf_file): +@pytest.mark.parametrize( + "sdf_file,num_beads", + [ + (dpath / "benzene.sdf", 3), + (dpath / "ibuprofen.sdf", 5) + ] +) +def test_auto_martini_run_sdf(sdf_file: str, num_beads: int): mol = auto_martini.topology.gen_molecule_sdf(str(sdf_file)) - auto_martini.solver.cg_molecule(mol, "MOL", "mol.top") - Path("mol.top").unlink() + cg_mol = auto_martini.solver.Cg_molecule(mol, "MOL") + assert len(cg_mol.cg_bead_names) == num_beads @pytest.mark.parametrize( - "smiles,top_file,name", + "smiles,top_file,name,num_beads", [ - ("N1=C(N)NNC1N", "valid_GUA.top", "GUA"), - ("CCC", "valid_PRO.top", "PRO"), + ("N1=C(N)NNC1N", "valid_GUA.top", "GUA", 2), + ("CCC", "valid_PRO.top", "PRO", 1), ] ) -def test_auto_martini_run_smiles(smiles: str, top_file: Path, name: str): +def test_auto_martini_run_smiles(smiles: str, top_file: Path, name: str, num_beads: int): mol, _ = auto_martini.topology.gen_molecule_smi(smiles) - auto_martini.solver.cg_molecule(mol, name, "mol.top") + cg_mol = auto_martini.solver.Cg_molecule(mol, name) # assert filecmp.cmp(dpath / top_file, "mol.top") - # need to figure out a better way of comparing contents - Path("mol.top").unlink() + assert len(cg_mol.cg_bead_names) == num_beads