# CLIP Mol

Make molecules that look like a given text prompt. This was built using [SELFIES](https://github.com/aspuru-guzik-group/selfies) to generate the molecules, [rdkit](https://www.rdkit.org/) to draw the molecules, [CLIP](https://github.com/openai/CLIP) to compare the images to the text prompt, and [pymoo](https://pymoo.org) to optimize the molecules' agreement with CLIP. 

Here are some examples:

### Bird
![Molecule that looks like a bird](https://raw.githubusercontent.com/whitead/clipmol/main/examples/bird.png)

### Cat
![Molecule that looks like a cat](https://raw.githubusercontent.com/whitead/clipmol/main/examples/cat.png)


### Fir Tree Animation
<details>
<summary>Click to show</summary>

![Time lapse of molecule turning into a fir tree](https://raw.githubusercontent.com/whitead/clipmol/main/examples/christmas.gif)

</details>

## FAQ

[See FAQ](https://github.com/whitead/clipmol#faq)



In [None]:
#@title Install and Load Packages
#@markdown Execute this cell to load packages

!pip install ftfy regex tqdm rdkit-pypi pymoo selfies
!pip install git+https://github.com/openai/CLIP.git
!apt install imagemagick

import os
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from collections import OrderedDict
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG
import selfies as sf
from IPython.display import display
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
import rdkit.Chem.rdMolDescriptors
from rdkit.Chem.Draw import rdDepictor
import json
from urllib.request import urlopen
import numpy as np
import torch
import clip
from google.colab import files


IPythonConsole.ipython_useSVG = True
%matplotlib inline

# load clip model
# clip.load returns the network together with the image transform that
# draw() applies to every rendered molecule before it is encoded.

model, preprocess = clip.load('ViT-B/32')
# move the model to the GPU and switch it to evaluation mode
model.cuda().eval()
# side length used for both dimensions of every molecule image drawn below
input_resolution = model.visual.input_resolution
# NOTE(review): context_length and vocab_size are not referenced anywhere
# else in the visible code -- confirm whether they are still needed
context_length = model.context_length
vocab_size = model.vocab_size

def draw(smiles, dos):
  """Render SMILES strings as CLIP-ready image arrays.

  Each SMILES string is parsed with RDKit, drawn as a square image whose
  side is `input_resolution` using the drawing options `dos`, converted
  to RGB and passed through CLIP's `preprocess` transform.

  Returns the preprocessed images stacked along a new leading axis.
  """
  image_size = (input_resolution, input_resolution)
  processed = []
  for smi in smiles:
    mol = Chem.MolFromSmiles(smi)
    picture = Chem.Draw.MolToImage(mol, size=image_size, options=dos)
    processed.append(preprocess(picture.convert('RGB')))
  return np.stack(processed)
def score(smiles, text_features, dos):
  """Score how well each drawn molecule agrees with the encoded prompt.

  The molecules are rendered with `draw`, encoded by the CLIP image
  encoder and scaled to unit length. The result is the negated dot
  product between the first row of `text_features` and every image
  embedding, so lower values mean better agreement (the optimiser
  minimises).
  """
  batch = torch.tensor(draw(smiles, dos)).cuda()
  with torch.no_grad():
    embeddings = model.encode_image(batch).float()
    embeddings /= embeddings.norm(dim=-1, keepdim=True)
    text_vector = text_features.cpu().numpy()[0]
    image_matrix = embeddings.cpu().numpy()
    return -(text_vector @ image_matrix.T)
    
def similarity(smiles, ref):
    """Return the Tanimoto similarity of every SMILES string to `ref`.

    Both sides are compared through Morgan bit-vector fingerprints of
    radius 2. The result is a NumPy array with one value per entry of
    `smiles`, in the same order.
    """
    morgan_fp = rdkit.Chem.rdMolDescriptors.GetMorganFingerprintAsBitVect
    molecules = [Chem.MolFromSmiles(smi) for smi in smiles]
    fingerprints = [morgan_fp(mol, 2) for mol in molecules]
    reference_fp = morgan_fp(Chem.MolFromSmiles(ref), 2)
    return np.array([TanimotoSimilarity(fp, reference_fp) for fp in fingerprints])

# used to get token counts
# (disabled: the notebook loads the precomputed tokens.json right after this)
if False:
  # make alphabet with oversampling in important tokens
  data_url = "https://github.com/aspuru-guzik-group/selfies/raw/16a489afa70882428bc194b2b24a2d33573f1651/examples/vae_example/datasets/dataJ_250k_rndm_zinc_drugs_clean.txt"
  pd_data = pd.read_csv(data_url)
  selfies_list = [sf.encoder(s) for s in pd_data.iloc[:, 0]]

  # '[nop]' is seeded so that it always appears in the table
  selfies_symbol_counts = {"[nop]": 0}


  def parse(s):
      """Add every SELFIES token found in `s` to `selfies_symbol_counts`."""
      for si in s.split("[")[1:]:
          token = "[" + si
          # the first occurrence of a token must count as 1, not 0
          selfies_symbol_counts[token] = selfies_symbol_counts.get(token, 0) + 1


  for s in selfies_list:
      parse(s)
  # most frequent tokens first
  sorted_token_counts = sorted(selfies_symbol_counts.items(), key=lambda i: -i[1])
  for p in sorted_token_counts[:10]:
      print(*p)
  # save the sorted token counts; the file has to be opened for writing
  with open('tokens.json', 'w') as f:
    json.dump(sorted_token_counts, f)
  files.download('tokens.json')

# load tokens
url = "https://raw.githubusercontent.com/whitead/clipmol/main/tokens.json"
# the context manager closes the HTTP response once the table has been read
with urlopen(url) as response:
  sorted_token_counts = json.loads(response.read())

In [None]:

#@title Run CLIP Mol


prompt = "A symmetric fir tree" #@param {type:"string"}
random_seed =  2#@param
mol_drawer = "rdkit" #@param ["rdkit", "CoordGen"]
black_and_white = True #@param
#@markdown *whether to consider color when computing agreement with prompt*
iterations = 500 #@param
mol_diversity = 2 #@param {type:"slider", min:1, max:10, step:1}
#@markdown *1 = only carbon, increase to add more types of elements/bonds*
mol_size = 100 #@param {type:"slider", min:30, max:300, step:10}

#@markdown *molecule size is number of SELFIES tokens, not number of atoms*

from pymoo.algorithms.soo.nonconvex.ga import GA
from pymoo.algorithms.moo.nsga2 import NSGA2
import pymoo.factory
from pymoo.optimize import minimize
from pymoo.visualization.scatter import Scatter
from pymoo.core.problem import Problem
from pymoo.core.callback import Callback
from pymoo.core.evaluator import Evaluator
from pymoo.core.population import Population
from IPython.display import clear_output, display

# Build the SELFIES alphabet: keep the M most frequent tokens, then pad
# towards N entries by repeating tokens in proportion to their counts.
M = mol_diversity + 5
N = max(M, 30)
top_tokens = sorted_token_counts[:M]
alphabet = [token for token, _ in top_tokens]
sum_counts = sum(count for _, count in top_tokens)
# more frequent tokens get more repeats, so they are sampled more often
for token, count in top_tokens:
  repeats = int(count / sum_counts * (N - M))
  alphabet.extend([token] * repeats)
alphabet.sort()
alphabet.insert(0, '[nop]')

# carbon only? then replace the alphabet with a fixed carbon/ring/branch set
if mol_diversity == 1:
  alphabet = ['[nop]', '[C]', '[Ring1]', '[Branch1]', '[Branch2]', '[Ring2]']
# integer index -> SELFIES token
vocab_itos = dict(enumerate(alphabet))

def ints2smiles(ints):
  """Decode a sequence of alphabet indices into a SMILES string.

  The indices are mapped to SELFIES tokens through `vocab_itos` (label
  encoding) and the resulting SELFIES string is decoded to SMILES.
  """
  return sf.decoder(sf.encoding_to_selfies(ints, vocab_itos, 'label'))

class CLIPMol(Problem):
    """pymoo problem whose variables are indices into the SELFIES alphabet.

    Every candidate is a vector of `L` integers in
    `[0, len(alphabet) - 1]`. Three objectives are minimised: similarity
    to the best-scoring molecule of the batch, negated SMILES length
    (capped at -50), and the negated CLIP agreement with the prompt.
    """

    def __init__(self, L, text, black_white=True):
        """Set up the search space and encode the prompt.

        L: number of SELFIES tokens per candidate.
        text: prompt that the drawn molecules are compared against.
        black_white: draw atoms with the black-and-white palette.
        """
        top_index = len(alphabet) - 1
        super().__init__(n_var=L, n_obj=3, n_constr=0, xl=[0] * L, xu=[top_index] * L)
        prompt_tokens = clip.tokenize([text]).cuda()
        self.L = L
        self.dos = Chem.Draw.MolDrawOptions()
        if black_white:
            self.dos.useBWAtomPalette()

        with torch.no_grad():
            features = model.encode_text(prompt_tokens).float()
        # scale to unit length so the dot product in score() is a cosine
        features /= features.norm(dim=-1, keepdim=True)
        self.text_features = features

    def _evaluate(self, x, out, *args, **kwargs):
        """Fill `out["F"]` with the three objective columns for batch `x`."""
        smiles = [ints2smiles(row) for row in x]
        clip_scores = score(smiles, self.text_features, dos=self.dos)
        best_smiles = smiles[np.argmin(clip_scores)]
        # rewards longer SMILES, but only up to 50 characters
        neg_lengths = [-min(50, len(smi)) for smi in smiles]
        # minimising similarity to the current best pushes for diversity
        closeness = similarity(smiles, best_smiles)
        out["F"] = np.column_stack((closeness, neg_lengths, clip_scores))
class DrawCallback(Callback):
    """Record the best molecule of each generation and periodically draw
    a sample of the population.

    `results` and `scores` collect, per call, the SMILES string and the
    CLIP objective (last column of "F") of the best individual.
    """

    def __init__(self, period=10, display=9) -> None:
        """
        period: draw the population every `period` calls.
        display: number of molecules shown in the grid.
        """
        super().__init__()
        self.calls = 0
        self.period = period
        self.results = []
        self.display = display
        self.scores = []

    def notify(self, algorithm):
        """Store the current best molecule; redraw the grid when due."""
        self.calls += 1
        F = algorithm.pop.get("F")
        X = algorithm.pop.get("X")
        clip_scores = F[:, -1]
        best = np.argmin(clip_scores)
        self.results.append(ints2smiles(X[best]))
        self.scores.append(clip_scores[best])
        if self.calls % self.period != 0:
            return
        order = np.argsort(clip_scores)
        # downsample evenly from best to worst; the stride is kept at 1 or
        # more so the slice stays valid when the population is smaller
        # than the number of molecules to display
        stride = max(1, len(algorithm.pop) // self.display)
        order = order[:self.display * stride:stride]
        mols = [Chem.MolFromSmiles(ints2smiles(x)) for x in X[order]]
        clear_output(wait=True)
        display('Iteration {} / {}'.format(self.calls, iterations))
        display(Draw.MolsToGridImage(mols, molsPerRow=3, useSVG=True,
                                     legends=['Similarity: {:.3f}'.format(-s) for s in F[order, -1]],
                                     subImgSize=(input_resolution, input_resolution)))
          
# CoordGen ("fancy") coordinates are slow; they increase diversity but are
# usually worse, so they are only preferred when explicitly selected.
rdDepictor.SetPreferCoordGen(mol_drawer != 'rdkit')
my_problem = CLIPMol(mol_size, prompt, black_and_white)
c = DrawCallback(10)
# smaller population for the slower drawer
pop_size = 50 if mol_drawer != 'rdkit' else 150

# alternative crossover that was tried:
#   pymoo.factory.get_crossover("int_sbx", eta=N / M * 3)
algorithm = NSGA2(
    pop_size=pop_size,
    sampling=pymoo.factory.get_sampling("int_random"),
    crossover=pymoo.factory.get_crossover("int_exp"),
    mutation=pymoo.factory.get_mutation("int_pm", eta=N / M * 3),
)

res = minimize(
    my_problem,
    algorithm,
    termination=('n_gen', iterations),
    seed=random_seed,
    callback=c,
    verbose=False,
)


In [None]:
#@title Save Output Image

import glob
import os

# remove PNGs from earlier runs so that only this run's frames remain
for f in glob.glob('*.png'):
  os.remove(f)

dos = Chem.Draw.MolDrawOptions()
if black_and_white:
  dos.useBWAtomPalette()

# write one numbered frame each time the best molecule changes
last = ''
index = 0
for s, r in zip(c.scores, c.results):
  if r != last:
    f = '{:03d}.png'.format(index)
    m = Chem.Draw.MolToImage(Chem.MolFromSmiles(r), size=(512, 512), options=dos, legend=prompt + '\n\nScore: {:.3f}'.format(-s))
    m.save(f)
    index += 1
  last = r

# final frame: the best molecule of the whole run, labelled with its own
# score rather than the score left over from the last loop iteration
best_index = np.argmin(c.scores)
r = c.results[best_index]
s = c.scores[best_index]
f = '{:03d}.png'.format(index)
m = Chem.Draw.MolToImage(Chem.MolFromSmiles(r), size=(512, 512), options=dos, legend=prompt + '\n\nScore: {:.3f}'.format(-s))
m.save(f)
files.download(f)
print(r)


In [None]:
#@title Save GIF of Search
# Combine every PNG in the working directory into mol.gif with ImageMagick's
# `convert`, then offer the GIF for download.
!convert -delay 25 -loop 1 *.png mol.gif
files.download('mol.gif') 

In [None]:
#@title ZINC
#@markdown This searches through commercially available molecules in [ZINC20](https://zinc20.docking.org) to find a molecule that matches the prompt.
#@markdown *ZINC20 is usually overloaded, so this code may only run sometimes.*
tranches = pd.read_csv('https://gist.githubusercontent.com/whitead/f47887e45bbd2f38332182d2d422da6b/raw/a3948beac9b9034dab432b697c5ec238503ac5d0/tranches.txt')
def get_mol_batch(batch_size = 32):
  """Yield batches of SMILES strings from the tranche files.

  Each entry of `tranches` is read in turn as a space-separated table and
  its first column is yielded in consecutive slices of at most
  `batch_size` entries. The last slice of a tranche may be shorter, so
  no molecule is skipped.

  Raises ValueError if `batch_size` is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError('batch_size must be at least 1')
  for t in tranches.values:
    print('On tranch', t[0])
    d = pd.read_csv(t[0], sep=' ')
    column = d.iloc[:, 0].values
    # step through the column so the trailing partial batch is included
    for start in range(0, len(column), batch_size):
      yield column[start:start + batch_size]

# encode the prompt once; every batch is scored against it
text_tokens = clip.tokenize([prompt]).cuda()
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    # unit-normalise, as CLIPMol does, so the reported score is a cosine
    # similarity and comparable with the optimisation results
    text_features /= text_features.norm(dim=-1, keepdim=True)
# (SMILES, score) of the best molecule so far; 100 is above any real score
best = None, 100
search_count = 10000#@param
#@markdown *number of molecules to check*
for smiles in get_mol_batch():
  s = score(smiles, text_features, dos=dos)
  batch_best = np.argmin(s)
  if s[batch_best] < best[1]:
    best = smiles[batch_best], s[batch_best]
    display(Chem.Draw.MolToImage(Chem.MolFromSmiles(best[0])))
    print(str(best), search_count, 'remaining')
  # the budget counts molecules, so charge the whole batch against it
  search_count -= len(smiles)
  if search_count <= 0:
    break