Make BatchedMapper more flexible
avaucher committed Jun 16, 2023
1 parent a80c70f commit 282d444
Showing 1 changed file with 35 additions and 2 deletions.
rxnmapper/batched_mapper.py
@@ -1,7 +1,9 @@
 import logging
-from typing import Any, Dict, Iterable, Iterator, List
+from typing import Any, Dict, Iterable, Iterator, List, Optional

+import pkg_resources
 from rxn.utilities.containers import chunker
+from rxn.utilities.files import PathLike

 from .core import RXNMapper
@@ -20,10 +22,41 @@ class BatchedMapper:
     def __init__(
         self,
         batch_size: int,
+        model_path: Optional[PathLike] = None,
+        head: int = 5,
+        attention_multiplier: float = 90.0,
+        layer: int = 10,
+        model_type: str = "albert",
         canonicalize: bool = False,
         placeholder_for_invalid: str = ">>",
     ):
-        self.mapper = RXNMapper()
+        """
+        Args:
+            batch_size: batch size for inference.
+            model_path: path to the model directory, defaults to the model from
+                the original publication.
+            head: head related to atom mapping in the model. The default is the
+                one for the original publication.
+            attention_multiplier: attention multiplier, no need to change the default.
+            layer: layer, no need to change the default.
+            model_type: model type.
+            canonicalize: whether to canonicalize before predicting the atom mappings.
+            placeholder_for_invalid: placeholder to use in the output when there
+                is an issue in the prediction (number of tokens, invalid SMILES, ...).
+        """
+        if model_path is None:
+            model_path = pkg_resources.resource_filename(
+                "rxnmapper", "models/transformers/albert_heads_8_uspto_all_1310k"
+            )
+        self.mapper = RXNMapper(
+            config=dict(
+                model_path=str(model_path),
+                head=head,
+                layers=[layer],
+                model_type=model_type,
+                attention_multiplier=attention_multiplier,
+            ),
+        )
         self.batch_size = batch_size
         self.canonicalize = canonicalize
         self.placeholder_for_invalid = placeholder_for_invalid