Skip to content

Commit

Permalink
Add the following metric: true reactant accuracy (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
avaucher authored Aug 2, 2023
1 parent 703e10e commit a987053
Show file tree
Hide file tree
Showing 17 changed files with 541 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
with:
python-version: 3.8
- name: Install package dependencies
run: pip install -e .[rdkit]
run: pip install -e .[rdkit,rxnmapper]
- name: Install sphinx dependencies
run: pip install -r docs/requirements.txt
- name: Make docs
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
with:
python-version: 3.8
- name: Install Dependencies
run: pip install -e .[dev,rdkit]
run: pip install -e .[dev,rdkit,rxnmapper]
- name: Check black
run: python -m black --check --diff --color .
- name: Check isort
Expand Down
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ It has been tested on the following systems:
+ Linux: Ubuntu 18.04.4

A Python version of 3.6, 3.7, or 3.8 is recommended.
Python versions 3.9 and above are not expected to work due to compatibility with the selected version of OpenNMT.
Python versions 3.9 and above are not expected to work due to compatibility with the selected version of OpenNMT.

## Installation guide

The package can be installed from Pypi:
```bash
pip install rxn-metrics[rdkit]
pip install rxn-metrics[rdkit,rxnmapper]
```
You can leave out `[rdkit]` if RDKit is already available in your environment.
You can leave out the extra dependency `rdkit` if RDKit is already available in your environment.
Also, you can leave out the extra dependency `rxnmapper` if you don't plan on calculating the "true reactant" accuracy.

For local development, the package can be installed with:
```bash
pip install -e ".[dev,rdkit]"
pip install -e ".[dev,rdkit,rxnmapper]"
```

## Calculation of metrics
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ include_package_data = True
install_requires =
click>=8.0
pandas>=1.1.0
rxn-chem-utils>=1.1.4
rxn-chem-utils>=1.3.0
rxn-onmt-models>=1.0.0
rxn-onmt-utils>=1.0.0
rxn-utils>=1.1.9
Expand All @@ -54,6 +54,9 @@ rdkit =
# installation of RDKit
rdkit-pypi>=2021.3.2 ; python_version<"3.7"
rdkit>=2022.3.4 ; python_version>="3.7"
rxnmapper =
rxnmapper>=0.3.0
transformers<4.23.0 # Versions >=4.23.0 are not compatible with torch 1.5.1

[options.entry_points]
console_scripts =
Expand Down
35 changes: 35 additions & 0 deletions src/rxn/metrics/class_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from rxn.utilities.files import dump_list_to_file, iterate_lines_from_file

from .metrics_files import RetroFiles


def maybe_prepare_class_token_files(
    class_tokens: Optional[int], retro_files: RetroFiles
) -> None:
    """Write the class-token source and target files, if class tokens are used.

    Each line of ``retro_files.gt_src`` is written ``class_tokens`` times to
    ``retro_files.class_token_products``, each copy prefixed with the string
    returned by ``convert_class_token_idx_for_translation_models`` for the
    indices ``0 .. class_tokens - 1``. Each line of ``retro_files.gt_tgt`` is
    repeated ``class_tokens`` times in ``retro_files.class_token_precursors``,
    so that both output files stay aligned line by line.

    Args:
        class_tokens: the number of tokens used in the trainings. Nothing is
            done when it is None.
        retro_files: information on the location of files for the metrics.
    """
    if class_tokens is not None:

        def prefixed_products(product_lines):
            # One copy of every product per class token, token first.
            for product in product_lines:
                for token_idx in range(class_tokens):
                    yield f"{convert_class_token_idx_for_translation_models(token_idx)}{product}"

        def repeated_precursors(precursor_lines):
            # Same multiplicity as the products, without any prefix.
            for precursors in precursor_lines:
                for _ in range(class_tokens):
                    yield precursors

        # Both line iterators are created before anything is written,
        # and the output is streamed rather than held in memory.
        products_source = iterate_lines_from_file(retro_files.gt_src)
        precursors_source = iterate_lines_from_file(retro_files.gt_tgt)

        dump_list_to_file(
            prefixed_products(products_source), retro_files.class_token_products
        )
        dump_list_to_file(
            repeated_precursors(precursors_source),
            retro_files.class_token_precursors,
        )


def convert_class_token_idx_for_translation_models(class_token_idx: int) -> str:
    """Return the class token string for an index, i.e. the index in square brackets."""
    token_template = "[{}]"
    return token_template.format(class_token_idx)
52 changes: 51 additions & 1 deletion src/rxn/metrics/classification_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,68 @@

from rxn.chemutils.tokenization import file_is_tokenized, tokenize_file
from rxn.onmt_utils import translate
from rxn.utilities.files import is_path_exists_or_creatable
from rxn.utilities.files import dump_list_to_file, is_path_exists_or_creatable

from .metrics_files import RetroFiles
from .tokenize_file import (
classification_file_is_tokenized,
detokenize_classification_file,
tokenize_classification_file,
)
from .utils import combine_precursors_and_products_from_files

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def maybe_classify_predictions(
    classification_model: Optional[Path],
    retro_files: RetroFiles,
    batch_size: int,
    gpu: bool,
) -> None:
    """Classify the predicted reactions, as needed for the diversity metric.

    Nothing happens when no classification model is given. Otherwise, the
    canonical predicted precursors and products are first combined into a
    reaction file, which is then passed to the classification model.

    Args:
        classification_model: path to the classification model, or None to
            skip the classification altogether.
        retro_files: information on the location of files for the metrics.
        batch_size: batch size forwarded to the classification translation.
        gpu: forwarded to the classification translation.
    """
    if classification_model is not None:
        precursors_file = retro_files.predicted_canonical
        products_file = retro_files.predicted_products_canonical
        reactions_file = retro_files.predicted_rxn_canonical

        # Step 1: build the reaction file from precursors and products.
        create_rxn_from_files(precursors_file, products_file, reactions_file)

        # Step 2: predict one class per reaction (no ground truth available).
        classification_translation(
            src_file=reactions_file,
            tgt_file=None,
            pred_file=retro_files.predicted_classes,
            model=classification_model,
            n_best=1,
            beam_size=5,
            batch_size=batch_size,
            gpu=gpu,
        )


def create_rxn_from_files(
    input_file_precursors: Union[str, Path],
    input_file_products: Union[str, Path],
    output_file: Union[str, Path],
) -> None:
    """Combine a precursors file and a products file into one output file.

    The combination itself is delegated to
    ``combine_precursors_and_products_from_files``; its result is written to
    ``output_file``.

    Args:
        input_file_precursors: file containing the precursors.
        input_file_products: file containing the products.
        output_file: file to which the combined result is written.
    """
    logger.info(
        f'Combining files "{input_file_precursors}" and "{input_file_products}" -> "{output_file}".'
    )
    combined_reactions = combine_precursors_and_products_from_files(
        precursors_file=input_file_precursors,
        products_file=input_file_products,
    )
    dump_list_to_file(combined_reactions, output_file)


def classification_translation(
src_file: Union[str, Path],
tgt_file: Optional[Union[str, Path]],
Expand Down
3 changes: 2 additions & 1 deletion src/rxn/metrics/context_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from rxn.utilities.containers import chunker
from rxn.utilities.files import PathLike, iterate_lines_from_file

from .metrics import get_sequence_multiplier, top_n_accuracy
from .metrics import top_n_accuracy
from .metrics_calculator import MetricsCalculator
from .metrics_files import ContextFiles, MetricsFiles
from .utils import get_sequence_multiplier


class ContextMetrics(MetricsCalculator):
Expand Down
17 changes: 2 additions & 15 deletions src/rxn/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import numpy as np
from rxn.utilities.containers import chunker
from rxn.utilities.misc import get_multiplier

from .utils import get_sequence_multiplier

T = TypeVar("T")

Expand Down Expand Up @@ -161,17 +162,3 @@ def class_diversity(
}
std_dev = {i + 1: float(np.std(classes_for_n[i])) for i in range(multiplier)}
return classdiversity, std_dev


def get_sequence_multiplier(ground_truth: Sequence[T], predictions: Sequence[T]) -> int:
    """
    Get the multiplier for the number of predictions by ground truth sample.

    The lengths of both sequences are handed over to ``get_multiplier``.

    Raises:
        ValueError: if the lists have inadequate sizes (possibly forwarded
            from get_multiplier).
    """
    return get_multiplier(len(ground_truth), len(predictions))
2 changes: 2 additions & 0 deletions src/rxn/metrics/metrics_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(self, directory: PathLike):
)
self.predicted_rxn_canonical = self.directory / "predicted_rxn_canonical.txt"
self.predicted_classes = self.directory / "predicted_classes.txt"
self.gt_mapped = self.directory / "gt_mapped.txt"
self.predicted_mapped = self.directory / "predicted_mapped.txt"

@staticmethod
def reordered(path: PathLike) -> Path:
Expand Down
49 changes: 37 additions & 12 deletions src/rxn/metrics/retro_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional

from rxn.utilities.files import PathLike, iterate_lines_from_file
from rxn.utilities.files import PathLike, iterate_lines_from_file, load_list_from_file

from .metrics import class_diversity, coverage, round_trip_accuracy, top_n_accuracy
from .metrics_calculator import MetricsCalculator
from .metrics_files import MetricsFiles, RetroFiles
from .true_reactant_accuracy import true_reactant_accuracy


class RetroMetrics(MetricsCalculator):
Expand All @@ -22,15 +22,18 @@ def __init__(
gt_products: Iterable[str],
predicted_precursors: Iterable[str],
predicted_products: Iterable[str],
predicted_classes: Optional[Iterable[str]] = None,
predicted_classes: Optional[List[str]] = None,
gt_mapped_rxns: Optional[List[str]] = None,
predicted_mapped_rxns: Optional[List[str]] = None,
):
self.gt_products = list(gt_products)
self.gt_precursors = list(gt_precursors)
self.predicted_products = list(predicted_products)
self.predicted_precursors = list(predicted_precursors)
self.predicted_classes = (
list(predicted_classes) if predicted_classes is not None else None
)

self.predicted_classes = predicted_classes
self.gt_mapped_rxns = gt_mapped_rxns
self.predicted_mapped_rxns = predicted_mapped_rxns

def get_metrics(self) -> Dict[str, Any]:
topn = top_n_accuracy(
Expand All @@ -42,14 +45,20 @@ def get_metrics(self) -> Dict[str, Any]:
cov = coverage(
ground_truth=self.gt_products, predictions=self.predicted_products
)
if self.predicted_classes:
if self.predicted_classes is not None:
classdiversity, classdiversity_std = class_diversity(
ground_truth=self.gt_products,
predictions=self.predicted_products,
predicted_classes=self.predicted_classes,
)
else:
classdiversity, classdiversity_std = {}, {}
if self.gt_mapped_rxns is not None and self.predicted_mapped_rxns is not None:
reactant_accuracy = true_reactant_accuracy(
self.gt_mapped_rxns, self.predicted_mapped_rxns
)
else:
reactant_accuracy = {}

return {
"accuracy": topn,
Expand All @@ -58,6 +67,7 @@ def get_metrics(self) -> Dict[str, Any]:
"coverage": cov,
"class-diversity": classdiversity,
"class-diversity-std": classdiversity_std,
"true-reactant-accuracy": reactant_accuracy,
}

@classmethod
Expand All @@ -68,6 +78,9 @@ def from_metrics_files(cls, metrics_files: MetricsFiles) -> "RetroMetrics":
# Whether to use the reordered files - for class token
# To determine whether True or False, we check if the reordered files exist
reordered = RetroFiles.reordered(metrics_files.predicted_canonical).exists()
mapped = (
metrics_files.gt_mapped.exists() and metrics_files.predicted_mapped.exists()
)

return cls.from_raw_files(
gt_precursors_file=metrics_files.gt_tgt,
Expand All @@ -84,11 +97,15 @@ def from_metrics_files(cls, metrics_files: MetricsFiles) -> "RetroMetrics":
),
predicted_classes_file=(
None
if not os.path.exists(metrics_files.predicted_classes)
if not metrics_files.predicted_classes.exists()
else metrics_files.predicted_classes
if not reordered
else RetroFiles.reordered(metrics_files.predicted_classes)
),
gt_mapped_rxns_file=metrics_files.gt_mapped if mapped else None,
predicted_mapped_rxns_file=(
metrics_files.predicted_mapped if mapped else None
),
)

@classmethod
Expand All @@ -99,13 +116,21 @@ def from_raw_files(
predicted_precursors_file: PathLike,
predicted_products_file: PathLike,
predicted_classes_file: Optional[PathLike] = None,
gt_mapped_rxns_file: Optional[PathLike] = None,
predicted_mapped_rxns_file: Optional[PathLike] = None,
) -> "RetroMetrics":
# to simplify because it is called multiple times.
def maybe_load_lines(filename: Optional[PathLike]) -> Optional[List[str]]:
if filename is None:
return None
return load_list_from_file(filename)

return cls(
gt_precursors=iterate_lines_from_file(gt_precursors_file),
gt_products=iterate_lines_from_file(gt_products_file),
predicted_precursors=iterate_lines_from_file(predicted_precursors_file),
predicted_products=iterate_lines_from_file(predicted_products_file),
predicted_classes=None
if predicted_classes_file is None
else iterate_lines_from_file(predicted_classes_file),
predicted_classes=maybe_load_lines(predicted_classes_file),
gt_mapped_rxns=maybe_load_lines(gt_mapped_rxns_file),
predicted_mapped_rxns=maybe_load_lines(predicted_mapped_rxns_file),
)
Loading

0 comments on commit a987053

Please sign in to comment.