In [13]:
from rdchiral.template_extractor import extract_from_reaction
from pathlib import Path
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

# Locations for the USPTO-50k forward-template extraction.
data_dir = Path('../../data') / 'uspto_50k'          # raw split CSVs are read from data_dir / 'raw'
output_dir = data_dir.joinpath('forward_templates')  # one <split>.csv of template counts per split
splits = ['test', 'valid', 'train']                  # processing order of the splits

In [14]:
def extract_template_from_smarts(smarts: str) -> str:
    reactants = smarts.split('>>')[0]
    product = smarts.split('>>')[-1]
    reaction_dict = {'reactants': reactants, 'products': product, '_id': 0}
    out = extract_from_reaction(reaction_dict)
    return f'{out["reactants"]}>>{out["products"]}'

def extract_templates(split: str):
    """Count the forward templates extracted from one raw data split.

    Reads ``data_dir / 'raw' / f'{split}.csv'``. Each entry of its last column
    is passed to ``extract_template_from_smarts``, and the number of
    occurrences of every distinct template is counted.

    Parameters
    ----------
    split : str
        Split file name without the ``.csv`` suffix (e.g. ``'train'``).

    Returns
    -------
    pandas.DataFrame
        Columns ``reaction_smarts`` and ``count``, in descending order of count.
        Rows for which no template was extracted (``None``/NaN) are excluded.
    """
    raw_path = data_dir / 'raw' / f'{split}.csv'
    raw = pd.read_csv(raw_path)
    # NOTE(review): assumes the reaction SMILES is the last column of the raw CSV.
    reaction_column = raw.columns[-1]
    templates = raw[reaction_column].progress_apply(extract_template_from_smarts)

    # Make dropped reactions visible instead of losing them silently in value_counts.
    n_failed = int(templates.isna().sum())
    if n_failed:
        print(f'{split}: no template extracted for {n_failed} of {len(templates)} reactions')

    return (
        templates.dropna()
        .value_counts()
        .rename_axis('reaction_smarts')
        .reset_index(name='count')
    )

In [None]:
# Extract the template counts of every split and save each as <split>.csv.
output_dir.mkdir(parents=True, exist_ok=True)
for split in splits:
    df = extract_templates(split)
    out_path = output_dir / f'{split}.csv'
    df.to_csv(out_path, index=False)