Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
avaucher committed Aug 2, 2023
1 parent 350ea8e commit 56d8a7d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ include_package_data = True
install_requires =
click>=8.0
pandas>=1.1.0
rxn-chem-utils>=1.2.0
rxn-chem-utils>=1.3.0
rxn-onmt-models>=1.0.0
rxn-onmt-utils>=1.0.0
rxn-utils>=1.1.9
Expand Down
27 changes: 21 additions & 6 deletions src/rxn/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
from typing import Iterator, Sequence, TypeVar

from rxn.utilities.files import PathLike, iterate_lines_from_file
from rxn.utilities.misc import get_multiplier
from rxn.chemutils.reaction_combiner import ReactionCombiner
from rxn.chemutils.reaction_smiles import ReactionFormat
from rxn.utilities.files import PathLike, count_lines, iterate_lines_from_file
from rxn.utilities.misc import get_multiplier, get_multipliers

T = TypeVar("T")


def combine_precursors_and_products(
precursors: Iterator[str], products: Iterator[str]
precursors: Iterator[str],
products: Iterator[str],
total_precursors: int,
total_products: int,
) -> Iterator[str]:
"""
Combine two matching iterables of precursors/products into an iterator of reaction SMILES.
Args:
precursors: iterable of sets of precursors.
products: iterable of sets of products.
total_precursors: total number of precursors.
total_products: total number of products.
Returns:
iterator over reaction SMILES.
"""
combiner = ReactionCombiner(reaction_format=ReactionFormat.STANDARD_WITH_TILDE)

yield from (
f"{precursor_set}>>{product_set}"
for precursor_set, product_set in zip(precursors, products)
precursor_multiplier, product_multiplier = get_multipliers(
total_precursors, total_products
)

yield from combiner.combine_iterators(
precursors, products, precursor_multiplier, product_multiplier
)


Expand All @@ -39,10 +50,14 @@ def combine_precursors_and_products_from_files(
Returns:
iterator over reaction SMILES.
"""
n_precursors = count_lines(precursors_file)
n_products = count_lines(products_file)

yield from combine_precursors_and_products(
precursors=iterate_lines_from_file(precursors_file),
products=iterate_lines_from_file(products_file),
total_precursors=n_precursors,
total_products=n_products,
)


Expand Down
30 changes: 29 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,34 @@
import pytest
from rxn.utilities.files import dump_list_to_file, named_temporary_directory

from rxn.metrics.utils import get_sequence_multiplier
from rxn.metrics.utils import (
combine_precursors_and_products_from_files,
get_sequence_multiplier,
)


def test_combine_precursors_and_products_from_files() -> None:
# Make sure that things are combined properly with the precursors file
# containing twice as many lines as the products file.
with named_temporary_directory() as tmp_dir:
precursors_file = tmp_dir / "a"
products_file = tmp_dir / "b"

dump_list_to_file(
["CC.O", "CC.O.[Na+]~[Cl-]", "CCC.O", "NS.CCC.O"],
precursors_file,
)
dump_list_to_file(["CCO", "CCCO"], products_file)

results = combine_precursors_and_products_from_files(
precursors_file, products_file
)
assert list(results) == [
"CC.O>>CCO",
"CC.O.[Na+]~[Cl-]>>CCO",
"CCC.O>>CCCO",
"NS.CCC.O>>CCCO",
]


def test_get_sequence_multiplier() -> None:
Expand Down

0 comments on commit 56d8a7d

Please sign in to comment.