In [None]:
from IPython.display import clear_output, display
from PIL import Image
import json
from utils import *

# Run configuration for this review session.
dataset = 'USPTO_50K'
samp_iter = 1
chemist_name = get_user_name()

# Inputs: the sampled reactions to review and the accepted / rejected template
# collections returned by load_fixed_templates for the same
# (dataset, chemist, iteration) triple.
sampled_data = load_sampled_data(dataset, chemist_name, samp_iter)
accepted_templates, rejected_templates = load_fixed_templates(dataset, chemist_name, samp_iter)

# Outputs of the manual review, both keyed by data index: the accepted
# (possibly remapped) reaction and the template extracted from it.
remapped_rxn_dict, remapped_temp_dict = {}, {}

print (
    'Chemist Name: %s, Correcting %d reaction data' % (chemist_name, len(sampled_data)),
    'Loaded %d templates' % len(accepted_templates),
    sep='\n',
)

In [None]:
def clean_map(rxn):
    """Reset reactant atom-map numbers that do not occur in the product.

    Args:
        rxn: Reaction string of the form 'reactants>>products'.

    Returns:
        The reaction re-serialised with Chem.MolToSmiles as
        'reactants>>products', where every reactant atom whose map number is
        absent from the product has had its map number set to 0.

    Raises:
        ValueError: if `rxn` does not split into exactly two sides on '>>',
            or if either side cannot be parsed into a molecule.
    """
    r_smiles, p_smiles = rxn.split('>>')
    r_mol, p_mol = Chem.MolFromSmiles(r_smiles), Chem.MolFromSmiles(p_smiles)
    # NOTE(review): assumes Chem is rdkit.Chem (brought in by the utils
    # star-import), whose MolFromSmiles returns None on a parse failure.
    # Fail with a clear message instead of an AttributeError further down.
    if r_mol is None or p_mol is None:
        raise ValueError('Cannot parse reaction SMILES: %s' % rxn)

    # A set gives constant-time membership tests for the loop below.
    p_maps = {atom.GetAtomMapNum() for atom in p_mol.GetAtoms()}
    for atom in r_mol.GetAtoms():
        if atom.GetAtomMapNum() not in p_maps:
            atom.SetAtomMapNum(0)
    return '>>'.join(Chem.MolToSmiles(m) for m in (r_mol, p_mol))

In [None]:
## Manually check AAM
# Answers at the "Correct" prompt: '1' accepts the reaction, '2' rejects it;
# any other input (e.g. '0') asks for a remapped reactant side.
# NOTE(review): a rejected reaction is not recorded anywhere in this cell, and
# rejected_templates is not consulted here -- confirm that this is intended.
n_total = len(sampled_data)
review_rows = zip(sampled_data['data_idx'], sampled_data['mapped_rxn'],
                  sampled_data['template'], sampled_data['freq'])
# The stored template is ignored: it is always re-extracted from the cleaned
# reaction. Positions are 1-based so the prompts run from (1/N) to (N/N).
for pos, (idx, rxn, _stored_temp, freq) in enumerate(review_rows, start=1):
    if idx in remapped_rxn_dict:
        # Already recorded as accepted (e.g. when this cell is re-run).
        continue
    rxn = clean_map(rxn)
    r, p = rxn.split('>>')
    temp = extract_from_reaction(rxn)

    while True:
        if temp in accepted_templates:
            # Template already accepted: take the reaction without asking.
            answer = '1'
            break
        print (rxn)
        print ('Reactant: \n', r); print ('Template: \n', temp); print ('Frequency: \n', freq)
        save_reaction(rxn)  # presumably renders the reaction to mol.png
        display(Image.open('mol.png'))
        answer = input('Correct (%d/%d)?' % (pos, n_total))
        if answer in ('1', '2'):
            break
        remap = input('Remap (%d/%d)...' % (pos, n_total))
        if not is_valid_mapping(remap):
            print ('Not valid mapping!')
            continue
        # Replace the reactant side, keep the product, and re-extract the
        # template so the next pass of the loop checks the new mapping.
        r = remap
        rxn = '%s>>%s' % (r, p)
        temp = extract_from_reaction(rxn)

    # Show the final state of the reaction (after any remapping).
    save_reaction(rxn)
    display(Image.open('mol.png'))
    if answer == '1':
        remapped_rxn_dict[idx] = rxn
        remapped_temp_dict[idx] = temp
        accepted_templates.add(temp)

    # wait=True defers clearing until the next output arrives.
    clear_output(wait=True)

print ('Correction finished. Mapped %d reactions.' % len(remapped_rxn_dict))

In [None]:
# Export the accepted reactions ordered by data index.
remapped_idxs = sorted(remapped_temp_dict)
remapped_rxns = [remapped_rxn_dict[key] for key in remapped_idxs]
remapped_temps = [remapped_temp_dict[key] for key in remapped_idxs]

# One row per accepted reaction: its data index, mapped reaction and template.
df = pd.DataFrame({
    'data_idx': remapped_idxs,
    'mapped_rxn': remapped_rxns,
    'template': remapped_temps,
})
save_fixed_data(df, dataset, chemist_name, samp_iter)