diff --git a/structuretoolkit/__init__.py b/structuretoolkit/__init__.py index d8669e23a..6bc461501 100644 --- a/structuretoolkit/__init__.py +++ b/structuretoolkit/__init__.py @@ -88,6 +88,7 @@ center_coordinates_in_unit_cell, get_cell, get_extended_positions, + get_number_species_atoms, get_vertical_length, get_wrapped_coordinates, pymatgen_to_ase, @@ -113,6 +114,7 @@ "get_mean_positions", "get_neighborhood", "get_neighbors", + "get_number_species_atoms", "get_steinhardt_parameters", "get_strain", "get_symmetry", diff --git a/structuretoolkit/common/__init__.py b/structuretoolkit/common/__init__.py index 8c15b8aba..00a506e30 100644 --- a/structuretoolkit/common/__init__.py +++ b/structuretoolkit/common/__init__.py @@ -4,6 +4,7 @@ center_coordinates_in_unit_cell, get_cell, get_extended_positions, + get_number_species_atoms, get_vertical_length, get_wrapped_coordinates, select_index, @@ -22,6 +23,7 @@ "center_coordinates_in_unit_cell", "get_cell", "get_extended_positions", + "get_number_species_atoms", "get_vertical_length", "get_wrapped_coordinates", "select_index", diff --git a/structuretoolkit/common/helper.py b/structuretoolkit/common/helper.py index 0db936641..47029c959 100644 --- a/structuretoolkit/common/helper.py +++ b/structuretoolkit/common/helper.py @@ -5,6 +5,20 @@ from scipy.sparse import coo_matrix +def get_number_species_atoms(structure: Atoms): + """Returns a dictionary with the species in the structure and the corresponding count in the structure + + Args: + structure + + Returns: + dict: A dictionary with the species and the corresponding count + + """ + elements_lst = structure.get_chemical_symbols() + return {species: elements_lst.count(species) for species in set(elements_lst)} + + def get_extended_positions( structure: Atoms, width: float, diff --git a/tests/test_helpers.py b/tests/test_helpers.py index b7e5b7dfd..d754c3b83 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -22,6 +22,14 @@ def test_get_cell(self): stk.get_cell(np.arange(4)) with self.assertRaises(ValueError): stk.get_cell(np.ones((4, 3))) + + def test_get_number_species_atoms(self): + with self.subTest("Fe8"): + atoms = bulk("Fe").repeat(2) + self.assertEqual(stk.get_number_species_atoms(atoms), {'Fe': 8}) + with self.subTest('Al2Fe8'): + atoms = bulk('Fe').repeat(2) + bulk('Al').repeat((2,1,1)) + self.assertEqual(stk.get_number_species_atoms(atoms), {'Fe': 8, 'Al':2}) if __name__ == "__main__":