In [43]:
import sys
from collections import defaultdict
from pathlib import Path
import pandas as pd

sys.path.append('.')
from utils import canonicalize_smiles, get_ground_truth_dict
from tqdm import tqdm
from rdkit.Chem import rdChemReactions, MolFromSmiles, MolToSmiles

# Dataset root (a relative path, so it is resolved against the current working
# directory) and the folder that receives the generated negatives.
data_dir = Path('../../data/uspto_50k')
output_dir = data_dir.joinpath('negative_forward')
# Splits to process, in this order.
splits = ['test', 'valid', 'train']
# Number of most frequent forward templates applied to each reactant set.
max_n_templates = 2_000

In [44]:
def _read_template_counts(split: str) -> dict:
    """Return ``{reaction_smarts: count}`` read from ``forward_templates/<split>.csv`` under ``data_dir``.

    Templates keep their file order; if a template appears on several rows, the last count wins.
    """
    templates = pd.read_csv(data_dir / 'forward_templates' / f'{split}.csv')
    # Column-wise zip avoids building one pandas Series per row as DataFrame.iterrows does.
    return dict(zip(templates['reaction_smarts'], templates['count']))


def load_templates_dict(split: str, n_templates: int):
    """Return the ``n_templates`` most frequent forward templates available to *split*.

    Counts for *split* are read from ``forward_templates/<split>.csv``. For every split other
    than ``'train'``, the complete train counts are added on top; a template present in both
    files gets the sum of its two counts.

    Args:
        split: Dataset split name, e.g. ``'train'``, ``'valid'`` or ``'test'``.
        n_templates: Maximum number of templates to keep.

    Returns:
        Dict mapping reaction SMARTS to count, ordered by descending count. Equal counts keep
        the order in which templates were first seen (split file first, then train file).
    """
    counts = _read_template_counts(split)
    if split != 'train':
        # Merge the full train table directly rather than re-entering this function with a
        # huge sentinel limit just to disable truncation.
        for template, train_count in _read_template_counts('train').items():
            counts[template] = train_count + counts.get(template, 0)
    # sorted() is stable, so ties keep their insertion order.
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:n_templates])


def _candidate_products(rxn, reactant_mols):
    """Yield the canonicalized SMILES of every product *rxn* generates from *reactant_mols*.

    Best-effort by design: if RDKit fails to run the template, nothing is yielded, and a single
    product that cannot be written or canonicalized is skipped without abandoning its siblings.
    Yielded values are whatever ``canonicalize_smiles`` returns; the caller filters falsy ones.
    """
    try:
        outcomes = rxn.RunReactants(reactant_mols)
    except Exception:  # deliberate: any RDKit failure here just means "no products"
        return
    for products in outcomes:
        for product in products:
            try:
                smiles = canonicalize_smiles(MolToSmiles(product))
            except Exception:  # e.g. a malformed, unsanitizable template output
                continue
            yield smiles


def get_forward_negatives(split: str, max_reactants: 'int | None' = 100):
    """Generate negative (infeasible) forward reactions for *split*.

    The top ``max_n_templates`` forward templates (ranked by ``load_templates_dict``) are applied
    with RDKit to each positive reactant set of the split. A resulting canonical product becomes
    a negative pair unless the ground-truth dict lists these exact reactants for that product.

    Args:
        split: Dataset split name, e.g. ``'train'``, ``'valid'`` or ``'test'``.
        max_reactants: Only the first ``max_reactants`` positive reactant sets are processed;
            ``None`` processes all of them. The default of 100 reproduces the previous
            hard-coded cap.

    Returns:
        DataFrame with columns ``reactants``, ``product`` and ``feasible`` (always 0). Rows are
        sorted by ``(reactants, product)`` so the written CSV is identical across runs (set
        iteration order of strings varies with hash randomization).
    """
    templates = list(load_templates_dict(split, max_n_templates).keys())
    # RunReactants maps reactants to reactant templates positionally and raises when the counts
    # differ, so group templates by arity once instead of triggering that error per template.
    reactions_by_arity = defaultdict(list)
    for template in templates:
        rxn = rdChemReactions.ReactionFromSmarts(template)
        reactions_by_arity[rxn.GetNumReactantTemplates()].append(rxn)

    # presumably {canonical product SMILES: set of reactant SMILES} -- confirm in utils.
    product_to_reactants = get_ground_truth_dict(data_dir, split)
    positive_reactants = list(pd.read_csv(data_dir / 'positive' / f'{split}.csv')['reactants'].values)

    negative_reactions = set()
    # NOTE(review): the cap (formerly a hard-coded ``[:100]``) looks like a debugging limit; it is
    # kept as the default so outputs do not change. Pass ``max_reactants=None`` for the full split.
    for reactants in tqdm(positive_reactants[:max_reactants]):
        reactant_mols = [MolFromSmiles(reactant) for reactant in reactants.split('.')]
        if any(mol is None for mol in reactant_mols):
            continue  # unparsable SMILES: every template application would fail anyway
        for rxn in reactions_by_arity.get(len(reactant_mols), []):
            for product_smiles in _candidate_products(rxn, reactant_mols):
                # Very short products (<= 4 characters) are dropped -- presumably trivial fragments.
                if not product_smiles or len(product_smiles) <= 4:
                    continue
                if reactants not in product_to_reactants.get(product_smiles, set()):
                    negative_reactions.add((reactants, product_smiles))

    negative_df = pd.DataFrame(sorted(negative_reactions), columns=['reactants', 'product'])
    negative_df['feasible'] = 0
    return negative_df

In [None]:
# Make sure the destination folder exists, then write one negatives CSV per split.
output_dir.mkdir(parents=True, exist_ok=True)
for split in splits:
    df = get_forward_negatives(split)
    split_csv = output_dir / f'{split}.csv'
    df.to_csv(split_csv, index=False)