In [1]:
import sys
from pathlib import Path
from typing import List

import pandas as pd

# Make the local `utils` module importable; '.' is resolved against the current working directory.
sys.path.append('.')
from utils import get_ground_truth_dict

# Dataset root; positive reactions are read from `data_dir / 'positive' / '<split>.csv'`.
data_dir = Path('../../data/uspto_50k')
# Destination directory for the generated negative-reaction CSV files (one per split).
output_dir = data_dir / 'negative_shuffle'
# Dataset splits to process, in this order.
splits = ['test', 'valid', 'train']
# Torch device the fingerprint tensors are moved to before similarity computation.
device = 'cpu'

In [2]:
from torch import Tensor
import torch
import numpy as np
from rdkit import DataStructs
from tqdm import tqdm
from rdkit.Chem import AllChem, MolFromSmiles


def get_ecfp_tensor(smiles_list: List[str]) -> Tensor:
    """Encode SMILES strings as a stacked boolean Morgan-fingerprint tensor.

    Each SMILES is parsed with RDKit and converted to a Morgan bit-vector
    fingerprint (second argument 2, i.e. radius 2; RDKit's default bit-vector
    length is used).

    Args:
        smiles_list: Iterable of SMILES strings. Must be non-empty, because
            ``torch.stack`` cannot stack an empty list.

    Returns:
        A boolean tensor with one row per input SMILES and one column per
        fingerprint bit.

    Raises:
        ValueError: If a SMILES string cannot be parsed by RDKit.
    """
    ecfp_torch = []
    for smiles in smiles_list:
        mol = MolFromSmiles(smiles)
        if mol is None:
            # RDKit signals a parse failure by returning None; fail with the
            # offending string instead of an opaque error in the fingerprint call.
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
        ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, 2)
        # ConvertToNumpyArray resizes the target array to the fingerprint length.
        array = np.zeros((0,), dtype=np.int8)
        DataStructs.ConvertToNumpyArray(ecfp, array)
        ecfp_torch.append(torch.tensor(array).bool())
    return torch.stack(ecfp_torch)


def tanimoto_similarity(a: Tensor, rest: Tensor) -> Tensor:
    """Compute the Tanimoto similarity between one fingerprint and each row of a batch.

    Args:
        a: 1-D fingerprint tensor (the caller passes boolean bit vectors).
        rest: 2-D tensor whose rows are fingerprints of the same length as ``a``.

    Returns:
        A 1-D floating-point tensor with one similarity per row of ``rest``:
        ``|a & row| / |a | row|``. When both ``a`` and a row have no set bits
        the union is empty; that case is defined as similarity 1.0 (identical
        empty sets) instead of the NaN that 0/0 would produce.
    """
    in_a = a.sum()
    in_b = rest.sum(dim=1)
    in_both = (a.unsqueeze(0) * rest).sum(dim=1)
    union = in_a + in_b - in_both
    similarity = in_both / union
    # 0/0 yields NaN for an empty union; NaN would also corrupt any later sort.
    return similarity.masked_fill(union == 0, 1.0)


def get_shuffle_negatives(split: str, max_negatives_per_positive: int = 9) -> pd.DataFrame:
    """Create negative reactions by pairing each product with reactants of similar reactions.

    For every positive reaction in ``data_dir / 'positive' / f'{split}.csv'``,
    all positive reactions are ranked by the sum of the product-fingerprint and
    reactant-fingerprint Tanimoto similarities to it. Walking that ranking from
    most to least similar, the reactants of each ranked reaction are paired
    with the current product, skipping pairs whose reactants appear in
    ``product_to_reactants[product]`` (the ground truth returned by
    ``get_ground_truth_dict``) and pairs that were already emitted.

    Args:
        split: Name of the dataset split; selects the CSV file to read.
        max_negatives_per_positive: Upper bound on the number of negatives
            generated per positive row. Values <= 0 produce no negatives.

    Returns:
        A DataFrame with columns ``product``, ``reactants`` and ``feasible``
        (always 0), with rows in the order the negatives were generated.
    """
    product_to_reactants = get_ground_truth_dict(data_dir, split)
    positive_df = pd.read_csv(data_dir / 'positive' / f'{split}.csv')

    # Plain lists give cheap positional access that matches the row order of
    # the fingerprint tensors, independent of the DataFrame index labels.
    products = positive_df['product'].tolist()
    reactants = positive_df['reactants'].tolist()

    product_ecfp_torch = get_ecfp_tensor(products).to(device)
    reactant_ecfp_torch = get_ecfp_tensor(reactants).to(device)

    # A dict is used as an insertion-ordered set so that the output row order
    # is reproducible across runs (iteration order of a set of strings depends
    # on per-process hash randomization).
    negative_reactions = {}
    for idx, product in tqdm(enumerate(products), total=len(products)):
        similarity = (tanimoto_similarity(product_ecfp_torch[idx], product_ecfp_torch)
                      + tanimoto_similarity(reactant_ecfp_torch[idx], reactant_ecfp_torch))
        # Convert to Python ints once instead of indexing with 0-d tensors.
        similar_indices = torch.argsort(similarity.cpu(), descending=True).tolist()

        forbidden_reactants = product_to_reactants[product]
        added_count = 0
        for similar_idx in similar_indices:
            # Check the quota before adding so a non-positive limit yields nothing.
            if added_count >= max_negatives_per_positive:
                break
            similar_reactants = reactants[similar_idx]
            candidate = (product, similar_reactants)
            if similar_reactants not in forbidden_reactants and candidate not in negative_reactions:
                negative_reactions[candidate] = None
                added_count += 1

    negative_df = pd.DataFrame(list(negative_reactions), columns=['product', 'reactants'])
    negative_df['feasible'] = 0
    return negative_df


In [None]:
# Generate the shuffle negatives for every split and write one CSV per split.
output_dir.mkdir(parents=True, exist_ok=True)
for split in splits:
    raw = get_shuffle_negatives(split)
    split_csv = output_dir / f'{split}.csv'
    raw.to_csv(split_csv, index=False)