In [None]:
import os
import re
from pathlib import Path
from typing import Set, List

In [None]:
from paragraph2actions.action_string_converter import ReadableConverter
from paragraph2actions.analysis import (
    full_sentence_accuracy, original_bleu, partial_accuracy, levenshtein_similarity
)

In [None]:
from smiles2actions.utils import ReactionEquation, load_list_from_file, detokenize_smiles

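### Helper functions

Helpers that derive the compound placeholders expected for each reaction and measure the fraction of valid predictions.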

In [None]:
def expected_placeholders_for_src(src_line: str) -> Set[str]:
    """function to get the expected placeholders for the reaction equations"""
    smiles = detokenize_smiles(src_line)
    reaction_equation = ReactionEquation.from_string(reaction_string=smiles, fragment_bond='~')
    expected_precursors = [f'${index + 1}$' for index in range(len(reaction_equation.reactants))]
    expected_products = [f'$-{index + 1}$' for index in range(len(reaction_equation.products))]
    return set(expected_precursors + expected_products)

In [None]:
# Shared converter used by `validity` to check that a predicted string can be
# turned into actions; configured with ' ; ' as separator and an empty end mark.
converter = ReadableConverter(end_mark='', separator=' ; ')

In [None]:
def validity(expected_placeholders_lists: List[Set[str]], preds: List[str], action_converter=None) -> float:
    """Return the fraction of predictions that are valid action sequences.

    A prediction counts as valid when both of the following hold:

    1) every expected compound placeholder (e.g. ``$1$``, ``$-1$``) for the
       corresponding reaction occurs in the prediction string;
    2) the prediction can be converted to actions, i.e.
       ``action_converter.string_to_actions(pred)`` does not raise.

    Args:
        expected_placeholders_lists: one set of expected placeholders per sample.
        preds: predicted action strings, index-aligned with
            ``expected_placeholders_lists``.
        action_converter: object exposing ``string_to_actions(str)``. Defaults
            to the notebook-level ``converter`` when omitted.

    Returns:
        The fraction of valid predictions, in ``[0, 1]``.

    Raises:
        ValueError: if the two lists differ in length or are empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(expected_placeholders_lists) != len(preds):
        raise ValueError(
            f'Got {len(expected_placeholders_lists)} placeholder sets '
            f'but {len(preds)} predictions.'
        )
    if not preds:
        raise ValueError('Cannot compute validity of an empty prediction list.')

    if action_converter is None:
        action_converter = converter

    valid_samples = 0
    for placeholders, pred in zip(expected_placeholders_lists, preds):
        # Plain substring test; equivalent to re.search(re.escape(p), pred).
        if not all(placeholder in pred for placeholder in placeholders):
            continue
        # Deliberately broad: any failure to convert means the string is not
        # a valid action sequence.
        try:
            action_converter.string_to_actions(pred)
        except Exception:
            continue
        valid_samples += 1
    return valid_samples / len(preds)

### File locations

All inputs are read from the directory given by the `S2A_PAPER_DATA_DIR` environment variable.

In [None]:
# Root directory holding the test split and the model predictions; the
# S2A_PAPER_DATA_DIR environment variable must be set before running.
s2a_dir = Path(os.environ['S2A_PAPER_DATA_DIR'])
src_file, tgt_file = (str(s2a_dir / name) for name in ('src-test.txt', 'tgt-test.txt'))

In [None]:
# Prediction files of each model (presumably one line per test reaction,
# aligned with tgt-test.txt — confirm against how they were generated).
(
    transformer_file,
    bart_file,
    nn_file,
    random_file,
    random_same_length_file,
) = (
    str(s2a_dir / file_name)
    for file_name in (
        'transformer_test.txt',
        'bart_test.txt',
        'nn_test.txt',
        'random_test.txt',
        'random_same_smiles_length_test.txt',
    )
)

### Computation of the metrics

In [None]:
# (predictions file, display name) pairs, in reporting order. The reference
# file itself is listed first as "ground truth".
models = [
    (path, name)
    for name, path in {
        'ground truth': tgt_file,
        'transformer': transformer_file,
        'bart': bart_file,
        'nearest-neighbor': nn_file,
        'random': random_file,
        'random (same SMILES size)': random_same_length_file,
    }.items()
]

In [None]:
# Split the (path, name) pairs into two parallel, index-aligned lists.
model_paths = [path for path, _name in models]
model_names = [name for _path, name in models]

In [None]:
# Reference action sequences, plus the predictions of every entry in `models`
# (index-aligned with `model_names`).
truth = load_list_from_file(tgt_file)
data = [load_list_from_file(model_path) for model_path in model_paths]

# Predictions are paired with references by position, so a truncated or
# misaligned file would silently skew the metrics; fail loudly instead.
length_mismatches = {
    path: len(preds) for path, preds in zip(model_paths, data) if len(preds) != len(truth)
}
if length_mismatches:
    raise ValueError(
        f'Expected {len(truth)} predictions per file (one per reference line), '
        f'got: {length_mismatches}'
    )

In [None]:
# Tokenized source reactions (one per test sample) and, for each, the set of
# placeholders a valid prediction must mention.
src = load_list_from_file(src_file)
expected_placeholders = list(map(expected_placeholders_for_src, src))

In [None]:
# Thresholds for `partial_accuracy`, reported for every model.
PARTIAL_ACCURACY_THRESHOLDS = (1.0, 0.9, 0.75, 0.5)

for model_name, pred in zip(model_names, data):
    print(model_name)
    print(' - validity', validity(expected_placeholders, pred))
    print(' - full-sentence accuracy', full_sentence_accuracy(truth, pred))
    print(' - original BLEU', original_bleu(truth, pred))
    print(' - Levenshtein', levenshtein_similarity(truth, pred))
    for threshold in PARTIAL_ACCURACY_THRESHOLDS:
        # `:.0%` renders 1.0 -> "100%", 0.75 -> "75%", reproducing the labels
        # of the former hand-written calls.
        print(f' - {threshold:.0%} accuracy', partial_accuracy(truth, pred, threshold))