In [None]:
from synspace.synspace import forward, retro, mannifold_retro
from synspace.utils import get_fp, remove_dups, extract_one, flatten
import os
import vdict
import numpy as np
from synspace.data import get_reactions, get_blocks
from rdkit import Chem
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
import random
import tqdm
import time


def embed_mols(mols, extra_x=None, extra_d=None, dims=8):
    from sklearn.decomposition import PCA
    from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity

    fps = [get_fp(m) for m in mols]
    M = np.array([BulkTanimotoSimilarity(f, fps) for f in fps])
    dist_mat = 1 - M
    if extra_x is not None:
        dist_mat = np.concatenate([dist_mat, np.array(extra_x)], axis=1)
    pca = PCA(n_components=dims)
    proj_dmat = pca.fit_transform(dist_mat)
    # ensure each column is bounded from 0 to 1
    proj_dmat = (proj_dmat - proj_dmat.min(axis=0)) / (
        proj_dmat.max(axis=0) - proj_dmat.min(axis=0)
    )
    result = vdict.vdict(tol=10**10)
    for i, m in enumerate(mols):
        result[proj_dmat[i]] = (m, extra_d[i]) if extra_d is not None else m
    return result


def one_hot(x, N):
    """Return one-hot row(s) of length *N* for the integer index or indices *x*.

    A scalar index yields a 1-D array of shape ``(N,)``; a sequence of indices
    yields one row per index.
    """
    identity_rows = np.eye(N)
    indices = np.asarray(x)
    return identity_rows[indices]


def multilabel(x, N):
    """Return a length-*N* float vector with 1.0 at every index listed in *x*.

    ``None`` or an empty collection yields all zeros; repeated indices are harmless.
    """
    vec = np.zeros(N)
    has_labels = x is not None and len(x) > 0
    if has_labels:
        vec[np.array(x)] = 1
    return vec


def embed_blocks(blocks, dims, max_per_position=100):
    """Embed building blocks together with their reaction / reactant-position labels.

    Every block contributes its Tanimoto-distance row plus a one-hot of its
    reaction index and a one-hot of its reactant position. These features are
    passed to ``embed_mols``.

    Args:
        blocks: mapping from reaction name to a sequence with one entry per
            reactant position, each entry a sliceable sequence of molecules.
        dims: number of embedding dimensions (forwarded to ``embed_mols``).
        max_per_position: at most this many molecules are taken from each
            reactant position (default 100, as before).

    Returns:
        The ``embed_mols`` result; each stored value is ``(mol, [name, position])``.

    Raises:
        ValueError: if *blocks* yields no molecules to embed.
    """
    mols, labels, extra_d = [], [], []
    for rxn_idx, (name, r_pos) in enumerate(blocks.items()):
        for pos, candidates in enumerate(r_pos):
            for block in candidates[:max_per_position]:
                mols.append(block)
                labels.append((rxn_idx, pos))
                extra_d.append([name, pos])
    if not mols:
        raise ValueError("blocks contains no building blocks to embed")

    n_rxns = len(blocks)
    # Position one-hot keeps its historical width of 2, and widens only when a
    # reaction has more reactant positions (which used to raise IndexError).
    n_pos = max(2, max(pos + 1 for _, pos in labels))
    extra_x = [
        np.concatenate((one_hot(rxn_idx, n_rxns), one_hot(pos, n_pos)))
        for rxn_idx, pos in labels
    ]
    return embed_mols(mols, extra_x, extra_d, dims=dims)


def _retro_starting_points(mol, rxns, use_mannifold, strict, retro_steps, _pbar=None):
    """Return ``(mols, props)``: candidate starting molecules from retrosynthesis of *mol*.

    Uses ``mannifold_retro`` when *use_mannifold* is truthy. Otherwise it runs
    local ``retro``, expands for ``retro_steps - 1`` extra rounds, and removes
    duplicates.
    """
    if use_mannifold:
        if _pbar:
            _pbar.set_description("⚗️Synspace Retrosynthesis (Mannifold)⚗️")
        mols, props = mannifold_retro(mol)
        return mols, props
    if _pbar:
        _pbar.set_description("⚗️Synspace Retrosynthesis...⚗️")
    retro_strict = False if strict is None else strict
    mols, props = retro(mol, rxns=rxns, strict=retro_strict)
    for _ in range(retro_steps - 1):
        # NOTE(review): each round re-expands every molecule gathered so far, not
        # only the newest ones. Duplicates are dropped by remove_dups below.
        expansions = [
            retro(m, rxns=rxns, strict=retro_strict, start_props=p)
            for m, p in zip(mols, props)
        ]
        for new_mols, new_props in expansions:
            mols.extend(new_mols)
            props.extend(new_props)
            if _pbar:
                # tqdm's update(n) increments by n, so report only what was added.
                _pbar.update(len(new_mols))
    return remove_dups(mols, props)


def qd_chemical_space(
    mol,
    steps=(1, 1),
    threshold=0.2,
    blocks=None,
    rxns=None,
    use_mannifold=None,
    strict=None,
    nblocks=25,
    num_samples=250,
    _pbar=None,
    embed_dim=32,
    total_itrs=1000,
):
    """Explore blocks/reactions via quality (Tanimoto) and diversity (atoms affected).

    Each solution vector selects a starting molecule (from retrosynthesis) and,
    for each forward step, a building block with its reaction and reactant
    position. Selection is by lookup in PCA embeddings. Quality is the Tanimoto
    similarity between the final molecule and *mol*. The behaviour descriptor
    marks the atoms of *mol* covered by substructure matches of the molecule
    entering the last forward step. It is all zeros if that molecule no longer
    matches.

    Args:
        mol: target molecule, as an RDKit ``Mol`` or a SMILES string.
        steps: ``(retro_steps, forward_steps)``; an int ``n`` means ``(0, n)``.
            Local retrosynthesis runs ``retro_steps - 1`` extra expansion rounds.
        threshold, nblocks, num_samples: accepted for compatibility; not used here.
        blocks: building blocks keyed by reaction name; defaults to ``get_blocks()``.
        rxns: reactions keyed by name (``rxns[name][0]`` is the reaction run);
            defaults to ``get_reactions()``.
        use_mannifold: use ``mannifold_retro`` instead of local ``retro``.
            Defaults to whether ``POSTERA_API_KEY`` is set in the environment.
        strict: forwarded to ``retro`` (``None`` means ``False``). In forward
            steps, a truthy value accepts a reaction only when it yields exactly
            one outcome.
        _pbar: optional tqdm-like progress bar updated during retrosynthesis.
        embed_dim: embedding width per forward step; also the upper bound on the
            starting-molecule embedding width.
        total_itrs: number of ask/tell iterations of the optimizer.

    Returns:
        The pyribs ``CVTArchive``; each elite's metadata is its final molecule.

    Raises:
        ValueError: if *mol* is an unparseable SMILES string, or retrosynthesis
            yields no starting molecules.
    """
    from ribs.archives import CVTArchive
    from ribs.emitters import ImprovementEmitter
    from ribs.optimizers import Optimizer

    if isinstance(mol, str):
        smiles, mol = mol, Chem.MolFromSmiles(mol)
        if mol is None:
            raise ValueError(f"could not parse SMILES {smiles!r}")
    mol_fp = get_fp(mol)
    natoms = mol.GetNumAtoms()
    if isinstance(steps, int):
        steps = (0, steps)
    if blocks is None:
        blocks = get_blocks()
    if rxns is None:
        rxns = get_reactions()
    if use_mannifold is None:
        # Mannifold retrosynthesis is the default only when its API key is set.
        use_mannifold = os.environ.get("POSTERA_API_KEY") is not None

    mols, props = _retro_starting_points(
        mol, rxns, use_mannifold, strict, steps[0], _pbar
    )
    if len(mols) == 0:
        raise ValueError("retrosynthesis produced no starting molecules")

    # PCA cannot produce more components than there are molecules.
    mol_embed_dim = min(len(mols), embed_dim)
    eretro = embed_mols(mols, dims=mol_embed_dim) if len(mols) > 1 else None
    # Block embeddings carry [reaction name, reactant position] as payload.
    eblocks = embed_blocks(blocks, dims=embed_dim)

    # x is split into slices. The first mol_embed_dim entries pick the starting
    # molecule; each following embed_dim-wide slice picks one forward reaction.
    def _simulate(x):
        if eretro is not None:
            m = eretro[x[:mol_embed_dim]]
        else:
            m = mols[0]
        match = None
        for i in range(steps[1]):
            lo = mol_embed_dim + embed_dim * i
            m1, (name, pos) = eblocks[x[lo : lo + embed_dim]]
            rxn = rxns[name][0]
            reactants = [None] * rxn.GetNumReactantTemplates()
            j = 0
            if len(reactants) > 1:
                reactants[pos] = m1
                j = reactants.index(None)
            reactants[j] = m
            match = flatten(mol.GetSubstructMatches(m))
            if not match:
                break
            products = rxn.RunReactants(reactants)
            if len(products) == 1 or (len(products) > 0 and not strict):
                m, _ = extract_one(products)
        if match is None:
            # No forward steps were requested; describe the starting molecule itself.
            match = flatten(mol.GetSubstructMatches(m))
        return TanimotoSimilarity(mol_fp, get_fp(m)), multilabel(match, natoms), m

    action_dim = mol_embed_dim + embed_dim * steps[1]
    # One behaviour dimension per target atom; multilabel values are 0/1, padded slightly.
    archive = CVTArchive(10**3, [(-0.01, 1.01)] * natoms, use_kd_tree=True)
    initial_model = np.zeros(action_dim)
    emitters = [
        ImprovementEmitter(
            archive,
            initial_model,
            0.1,  # Initial step size.
            batch_size=30,
        )
        for _ in range(5)  # Create 5 separate emitters.
    ]
    optimizer = Optimizer(archive, emitters)

    start_time = time.time()
    for itr in tqdm.tqdm(range(1, total_itrs + 1)):
        sols = optimizer.ask()

        objs, bcs, ms = [], [], []
        for x in sols:
            o, b, m = _simulate(x)
            objs.append(o)
            bcs.append(b)
            ms.append(m)

        optimizer.tell(objs, bcs, metadata=ms)

        if itr % 25 == 0:
            elapsed_time = time.time() - start_time
            print(f"> {itr} itrs completed after {elapsed_time:.2f} s")
            print(f"  - Archive Size: {len(archive)}")
            print(f"  - Max Score: {archive.stats.obj_max}")
    return archive

In [None]:
# Example: run the quality-diversity search around a target molecule.
# NOTE(review): this SMILES appears to be imatinib — confirm.
smi = "Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C"
mol = Chem.MolFromSmiles(smi)
# Uses all defaults, including the full 1000-iteration optimizer loop (slow).
archive = qd_chemical_space(mol)

In [None]:
# NOTE: `pd` holds a DataFrame of archive entries here, not the pandas module;
# the name shadows the conventional pandas alias in this notebook.
pd = archive.as_pandas()

In [None]:
# Show the first 25 rows of the archive DataFrame.
pd.head(n=25)

In [None]:
# Draw one random elite; its metadata is what qd_chemical_space passed to
# optimizer.tell, i.e. the molecule returned by _simulate.
e = archive.get_random_elite()
e.meta

In [None]:
# NOTE(review): confirm that `samples` exists on the installed pyribs archive
# class; elite sampling may be named differently (e.g. `sample_elites`) in
# other pyribs versions.
archive.samples(
    5,
)

In [None]:
# Display the target molecule as the cell output.
mol