In [2]:
from argparse import ArgumentParser

import torch
import sklearn
import torch.nn as nn

from utils import init_featurizer, mkdir_p, get_configure, load_model, load_dataloader, predict
from get_edit import write_edits


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def test(args):
    """Run the LocalRetro test pipeline for the dataset named in ``args``.

    The model, config, data and result paths are derived from
    ``args['dataset']`` and ``args['config']`` and stored back into ``args``.
    The output directories are then created via ``mkdir_p`` and the work is
    delegated to ``init_featurizer``, ``load_model``, ``load_dataloader`` and
    ``write_edits``.

    Args:
        args: Mutable settings mapping; must contain ``'dataset'`` and
            ``'config'``. It is updated in place with ``'model_path'``,
            ``'config_path'``, ``'data_dir'`` and ``'result_path'``.

    Returns:
        None.
    """
    checkpoint = 'LocalRetro_%s.pth' % args['dataset']
    args['model_path'] = '../models/%s' % checkpoint
    args['config_path'] = '../data/configs/%s' % args['config']
    args['data_dir'] = '../data/%s' % args['dataset']
    # The result file shares the checkpoint's stem, with a .txt extension.
    result_file = checkpoint.replace('.pth', '.txt')
    args['result_path'] = '../outputs/raw_prediction/%s' % result_file

    # Parent first, then the sub-directory that holds the raw predictions.
    for directory in ('../outputs', '../outputs/raw_prediction'):
        mkdir_p(directory)

    args = init_featurizer(args)
    # Arguments are evaluated left to right: model first, then the data loader.
    write_edits(args, load_model(args), load_dataloader(args))

In [4]:
# Settings for the test run; test() writes the derived paths into this dict.
default_args = dict(
    gpu='cuda:0',
    dataset='USPTO_50K_data',
    config='default_config.json',
    batch_size=16,
    top_num=100,
    num_workers=0,
    mode='test',
    # Use the first CUDA device when one is available, otherwise the CPU.
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
)
test(default_args)

Directory ../outputs already exists.
Directory ../outputs/raw_prediction already exists.
Parameters of loaded LocalRetro:
{'attention_heads': 8, 'attention_layers': 1, 'batch_size': 16, 'edge_hidden_feats': 64, 'node_out_feats': 320, 'num_step_message_passing': 6, 'AtomTemplate_n': 43, 'BondTemplate_n': 194, 'in_node_feats': 80, 'in_edge_feats': 13}
Loading previously saved test dgl graphs...
Writing test molecule batch 12/13


In [1]:
import os, sys, re
import pandas as pd
import multiprocessing
from tqdm import tqdm
from functools import partial
from collections import defaultdict
from argparse import ArgumentParser
sys.path.append('../')
    
import rdkit
from rdkit import Chem, RDLogger 
from rdkit.Chem import rdChemReactions

from utils import mkdir_p
from LocalTemplate.template_decoder import *

def get_k_predictions(test_id, args):
    """Decode the raw predictions of one test entry into result strings.

    Args:
        test_id: Key into ``args['raw_predictions']``.
        args: Shared settings dict. Reads ``raw_predictions``,
            ``atom_templates``, ``bond_templates``, ``template_infos``,
            ``top_k`` and ``rxn_class_given``; when the latter is truthy it
            also reads ``test_rxn_class`` and ``templates_class``.

    Returns:
        ``(test_id, (all_prediction, class_prediction))``. Both lists hold
        ``str((decoded_smiles, score))`` entries in the order they were
        decoded, without duplicates. ``class_prediction`` only receives entries
        whose template is listed under the reaction class of ``test_id`` and
        stays empty when no reaction class is given.
    """
    raw_prediction = args['raw_predictions'][test_id]
    product, predictions = raw_prediction[0], raw_prediction[1:]

    all_prediction = []
    class_prediction = []
    seen = set()  # O(1) duplicate check; all_prediction keeps the order.

    # The templates allowed for this entry do not change inside the loop.
    class_templates = None
    if args['rxn_class_given']:
        rxn_class = args['test_rxn_class'][test_id]
        class_templates = args['templates_class'][str(rxn_class)].values

    for prediction in predictions:
        mol, pred_site, template, template_info, score = read_prediction(
            product, prediction,
            args['atom_templates'], args['bond_templates'], args['template_infos'])
        # Wrap each side of the template (the part before the first '_') in parentheses.
        local_template = '>>'.join(
            '(%s)' % smarts for smarts in template.split('_')[0].split('>>'))

        # Decode exactly once, inside the guard. A second, unguarded call used
        # to run before this block, so a single failing template aborted the
        # whole entry instead of being skipped.
        try:
            decoded_smiles = decode_localtemplate(mol, pred_site, local_template, template_info)
        except Exception:
            # Best effort: a prediction that cannot be decoded is skipped.
            continue
        if decoded_smiles is None:
            continue

        entry = str((decoded_smiles, score))
        if entry in seen:
            continue
        seen.add(entry)
        all_prediction.append(entry)

        if args['rxn_class_given']:
            if template in class_templates:
                class_prediction.append(entry)
            # With a known class, stop once enough class-consistent hits exist.
            if len(class_prediction) >= args['top_k']:
                break
        elif len(all_prediction) >= args['top_k']:
            break

    return (test_id, (all_prediction, class_prediction))

def _read_raw_predictions(prediction_file):
    """Parse a tab-separated raw prediction file into ``{test_id: fields}``.

    The header row (first field ``Test_id``) and blank lines are skipped. The
    remaining fields of each row are kept verbatim, so the last one still
    carries the line's trailing newline.
    """
    raw_predictions = {}
    with open(prediction_file, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # A blank line would otherwise crash int('').
            seps = line.split('\t')
            if seps[0] == 'Test_id':
                continue
            raw_predictions[int(seps[0])] = seps[1:]
    return raw_predictions


def decode_prediction(args):
    """Decode the raw LocalRetro predictions of a dataset and write them out.

    Reads the template tables from ``../data/<dataset>/`` and the raw
    predictions from ``../outputs/raw_prediction/``, decodes every test entry
    with ``get_k_predictions`` in a pool of 8 worker processes, and writes one
    tab-separated line per test id to ``../outputs/decoded_prediction/`` and
    ``../outputs/decoded_prediction_class/``.

    Args:
        args: Mutable settings dict. Must contain ``'dataset'`` and ``'model'``
            (``'default'`` names the result file after the dataset), plus
            whatever ``get_k_predictions`` reads (e.g. ``'top_k'``). It is
            updated in place with ``rxn_class_given``, ``atom_templates``,
            ``bond_templates``, ``template_infos``, ``raw_predictions`` and,
            when ``class_test.csv`` exists, ``templates_class`` and
            ``test_rxn_class``.

    Returns:
        None.
    """
    data_dir = '../data/%s' % args['dataset']
    atom_templates = pd.read_csv('%s/atom_templates.csv' % data_dir)
    bond_templates = pd.read_csv('%s/bond_templates.csv' % data_dir)
    template_infos = pd.read_csv('%s/template_infos.csv' % data_dir)

    # Reaction-class information is optional and detected by file presence.
    class_test = '%s/class_test.csv' % data_dir
    if os.path.exists(class_test):
        args['rxn_class_given'] = True
        args['templates_class'] = pd.read_csv('%s/template_rxnclass.csv' % data_dir)
        args['test_rxn_class'] = pd.read_csv(class_test)['class']
    else:
        args['rxn_class_given'] = False

    args['atom_templates'] = {atom_templates['Class'][i]: atom_templates['Template'][i] for i in atom_templates.index}
    args['bond_templates'] = {bond_templates['Class'][i]: bond_templates['Template'][i] for i in bond_templates.index}
    # NOTE(security): eval() executes whatever the CSV cells contain. This is
    # only acceptable for trusted, locally generated template files; consider
    # ast.literal_eval if the cells are plain Python literals (TODO confirm).
    info_columns = ('edit_site', 'change_H', 'change_C', 'change_S')
    args['template_infos'] = {
        template_infos['Template'][i]: {col: eval(template_infos[col][i]) for col in info_columns}
        for i in template_infos.index
    }

    if args['model'] == 'default':
        result_name = 'LocalRetro_%s.txt' % args['dataset']
    else:
        result_name = 'LocalRetro_%s.txt' % args['model']

    prediction_file = '../outputs/raw_prediction/' + result_name
    raw_predictions = _read_raw_predictions(prediction_file)
    args['raw_predictions'] = raw_predictions

    output_path = '../outputs/decoded_prediction/' + result_name
    output_path_class = '../outputs/decoded_prediction_class/' + result_name
    # Create the output directories up front: open(..., 'w') does not create
    # them, and failing only after the expensive decoding step wastes the run.
    for path in (output_path, output_path_class):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # multi_processing
    result_dict = {}
    partial_func = partial(get_k_predictions, args=args)
    # Iterate over the ids actually present in the file instead of assuming
    # they are exactly 0..n-1, which raised KeyError in a worker otherwise.
    tasks = sorted(raw_predictions)
    with multiprocessing.Pool(processes=8) as pool:
        for result in tqdm(pool.imap_unordered(partial_func, tasks), total=len(tasks), desc='Decoding LocalRetro predictions'):
            result_dict[result[0]] = result[1]

    # Results arrive unordered; write them back sorted by test id.
    with open(output_path, 'w') as f1, open(output_path_class, 'w') as f2:
        for i in sorted(result_dict.keys()):
            all_prediction, class_prediction = result_dict[i]
            f1.write('\t'.join([str(i)] + all_prediction) + '\n')
            f2.write('\t'.join([str(i)] + class_prediction) + '\n')
            print('\rDecoding LocalRetro predictions %d/%d' % (i, len(raw_predictions)), end='', flush=True)
    print ()

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
# Settings for the decoding run. model="default" names the result file after
# the dataset; top_k caps the number of decoded predictions per test entry.
# The cell defining decode_prediction must have been executed in this kernel
# before this one.
default_args = dict(dataset="USPTO_50K_data", model="default", top_k=50)

decode_prediction(default_args)

NameError: name 'decode_prediction' is not defined