In [1]:
from torchdrug import data, datasets, utils
%matplotlib inline
import os
# Route HTTP(S) downloads (e.g. dataset fetches) through a local proxy.
# The default matches the original hardcoded value; set the NOTEBOOK_PROXY
# environment variable to override it, or to an empty string to skip
# configuring a proxy entirely (useful on machines without that proxy).
proxy = os.environ.get("NOTEBOOK_PROXY", 'http://127.0.0.1:10809')
if proxy:
    os.environ['https_proxy'] = proxy
    os.environ['http_proxy'] = proxy

In [2]:
# Load USPTO-50k from the same cache directory twice, both times with
# kekulized molecules:
#   * reaction_dataset - reactant/product pairs, featurized for center
#     identification;
#   * synthon_dataset  - built with as_synthon=True, featurized for
#     synthon completion.
reaction_dataset = datasets.USPTO50k(
    ".cache/molecule-datasets/",
    node_feature="center_identification",
    kekulize=True,
)
synthon_dataset = datasets.USPTO50k(
    ".cache/molecule-datasets/",
    as_synthon=True,
    node_feature="synthon_completion",
    kekulize=True,
)

Loading .cache/molecule-datasets/data_processed.csv: 100%|██████████| 50017/50017 [00:00<00:00, 133721.12it/s]
Constructing molecules from SMILES: 100%|██████████| 50016/50016 [02:33<00:00, 326.68it/s]
  scale = scale[-1] // scale
Computing reaction centers: 100%|██████████| 50016/50016 [01:26<00:00, 581.07it/s]
Loading .cache/molecule-datasets/data_processed.csv: 100%|██████████| 50017/50017 [00:00<00:00, 116312.59it/s]
Constructing molecules from SMILES: 100%|██████████| 50016/50016 [02:48<00:00, 296.83it/s]
Computing synthons: 100%|██████████| 50016/50016 [03:47<00:00, 219.55it/s]


In [None]:
import pandas as pd
df = pd.read_csv('.cache/molecule-datasets/data_processed.csv')
# Preview only the first rows; the processed file holds ~50k reactions
# (see the loading log above), so displaying the whole frame is noise.
df.head()

In [8]:
from torchdrug.utils import plot

# Draw the first five reactions. Each sample's "graph" entry unpacks into a
# reactant graph and a product graph; the first element returned by
# connected_components() is what gets passed to plot.reaction.
for i in range(5):
    sample = reaction_dataset[i]
    reactant, product = sample["graph"]
    reactants, products = (
        graph.connected_components()[0] for graph in (reactant, product)
    )
    plot.reaction(reactants, products)

In [6]:
# Draw the first three reactant -> synthon pairs from the synthon dataset.
# Each graph is wrapped in a one-element list before being handed to
# plot.reaction.
for i in range(3):
    sample = synthon_dataset[i]
    reactant, synthon = sample["graph"]
    plot.reaction(
        [reactant],
        [synthon],
    )

In [9]:
import torch

# Seed used for both dataset splits. Re-seeding immediately before each
# split means both split() calls start from the same torch RNG state
# (presumably so reaction and synthon partitions line up -- confirm with
# torchdrug's Dataset.split if that matters downstream).
SPLIT_SEED = 1

torch.manual_seed(SPLIT_SEED)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
torch.manual_seed(SPLIT_SEED)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()

In [10]:
from torchdrug import core, models, tasks

# Six-layer RGCN (256 hidden units per layer, one relation per bond type)
# whose per-layer hidden states are concatenated, wrapped in a center
# identification task that uses graph-, atom- and bond-level features.
reaction_model = models.RGCN(
    input_dim=reaction_dataset.node_feature_dim,
    hidden_dims=[256] * 6,
    num_relation=reaction_dataset.num_bond_type,
    concat_hidden=True,
)
reaction_task = tasks.CenterIdentification(
    reaction_model,
    feature=("graph", "atom", "bond"),
)

In [11]:
# Train the center-identification model, checkpoint it, then report
# validation metrics.
REACTION_CHECKPOINT = ".cache/g2gs_reaction_model.pth"

reaction_optimizer = torch.optim.Adam(reaction_task.parameters(), lr=1e-3)
reaction_solver = core.Engine(reaction_task, reaction_train, reaction_valid,
                              reaction_test, reaction_optimizer,
                              gpus=[0], batch_size=128)
reaction_solver.train(num_epoch=50)
# Save immediately after training so that a failure during evaluation
# cannot throw away the (long) training run.
reaction_solver.save(REACTION_CHECKPOINT)
# Last expression of the cell, so whatever evaluate() returns (presumably
# the validation metric dict) is rendered by Jupyter.
reaction_solver.evaluate("valid")

16:14:08   Preprocess training set
16:14:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:11   Epoch 0 begin
16:14:12   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:12   accuracy: 0.015625
16:14:12   cross entropy: 3.89416
16:14:21   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:21   accuracy: 0.5625
16:14:21   cross entropy: 1.28657
16:14:29   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:29   accuracy: 0.679688
16:14:29   cross entropy: 1.0165
16:14:37   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:37   accuracy: 0.671875
16:14:37   cross entropy: 0.963166
16:14:38   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:38   Epoch 0 end
16:14:38   duration: 29.33 secs
16:14:38   speed: 10.43 batch / sec
16:14:38   ETA: 23.95 mins
16:14:38   max GPU memory: 675.8 MiB
16:14:38   ------------------------------
16:14:38   average accuracy: 0.640388
16:14:38   average cross entropy: 1.36998
16:14:38   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:38   Epoch 1 begin
16:14:46   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:14:46   accuracy: 0.703125
16:14:46   cross e

In [None]:
# Gather the first 4 validation samples whose "reaction" values are all
# different, collate them into one batch, move it to the GPU and run
# synthon prediction on it.
samples = []
reaction_set = set()
for sample in reaction_valid:
    label = sample["reaction"]
    if label in reaction_set:
        continue
    reaction_set.add(label)
    samples.append(sample)
    if len(samples) == 4:
        break
batch = utils.cuda(data.graph_collate(samples))
result = reaction_task.predict_synthon(batch)

In [None]:
def atoms_and_bonds(molecule, reaction_center):
    """Return the atom and bond indices of ``molecule`` that form a reaction center.

    Args:
        molecule: graph exposing ``atom_map`` (atom-map number per atom, 0 for
            unmapped atoms) and ``edge_list`` (rows of ``[node_in, node_out, ...]``).
        reaction_center: length-2 tensor of atom-map numbers. A 0 entry never
            matches anything (presumably used for atom-only centers -- confirm
            against torchdrug's CenterIdentification).

    Returns:
        (atoms, bonds): lists of atom indices and bond indices. Bond indices
        are positions among the edges with ``node_in < node_out``, which
        assumes each undirected bond is stored as two directed edges in
        bond order (TODO confirm against torchdrug's Molecule).
    """
    atom_map = molecule.atom_map
    # An atom is in the center if it is mapped and its map number appears
    # anywhere in reaction_center.
    is_reaction_atom = (atom_map > 0) & \
        (atom_map.unsqueeze(-1) == reaction_center.unsqueeze(0)).any(dim=-1)

    node_in, node_out = molecule.edge_list.t()[:2]
    edge_map = atom_map[molecule.edge_list[:, :2]]
    center = reaction_center.unsqueeze(0)
    # Match the center pair in either order. Only one edge direction
    # survives the node_in < node_out filter below, so an order-sensitive
    # match would silently drop the bond whenever the center is given in
    # the direction of the discarded edge.
    matches_center = (edge_map == center).all(dim=-1) | \
        (edge_map == center.flip(-1)).all(dim=-1)
    is_reaction_bond = (edge_map > 0).all(dim=-1) & matches_center

    atoms = is_reaction_atom.nonzero().flatten().tolist()
    bonds = is_reaction_bond[node_in < node_out].nonzero().flatten().tolist()
    return atoms, bonds

# Highlight colors (RGB tuples passed to plot.highlight):
#   PURPLE - element is in both the ground-truth and the predicted center
#   RED    - element is only in the predicted center
#   BLUE   - element is only in the ground-truth center
RED = (1, 0.5, 0.5)
BLUE = (0.5, 0.5, 1)
PURPLE = (1, 0.5, 1)


def _highlight_colors(items, truth, pred):
    """Map each index in ``items`` to a highlight color.

    Indices present in both ``truth`` and ``pred`` get PURPLE, indices only
    in ``pred`` get RED, and all remaining indices (ground truth only) get
    BLUE. Shared by atoms and bonds so the two colorings cannot drift apart.
    """
    colors = {}
    for item in items:
        if item in truth and item in pred:
            colors[item] = PURPLE
        elif item in pred:
            colors[item] = RED
        else:
            colors[item] = BLUE
    return colors


products = batch["graph"][1]
reaction_centers = result["reaction_center"]

for i, product in enumerate(products):
    true_atoms, true_bonds = map(set, atoms_and_bonds(product, product.reaction_center))
    pred_atoms, pred_bonds = map(set, atoms_and_bonds(product, reaction_centers[i]))
    atoms = true_atoms | pred_atoms
    bonds = true_bonds | pred_bonds
    atom_colors = _highlight_colors(atoms, true_atoms, pred_atoms)
    bond_colors = _highlight_colors(bonds, true_bonds, pred_bonds)
    plot.highlight(product, atoms, bonds, atom_colors, bond_colors)