<a href="https://colab.research.google.com/github/weathon/3d2smile/blob/main/CLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install rdkit deepsmiles

Collecting deepsmiles
  Downloading deepsmiles-1.0.1-py2.py3-none-any.whl (12 kB)
Installing collected packages: deepsmiles
Successfully installed deepsmiles-1.0.1


In [8]:
!pip3 install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
!pip install tqdm boto3 requests regex sentencepiece sacremoses huggingface_hub

In [None]:
!pip install transformers -U

In [None]:
!wget http://file.weasoft.com/80k.csv

In [1]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
import deepsmiles
import numpy as np
import pylab
converter = deepsmiles.Converter(rings=True, branches=True)
def deepsmiles_to_img(ds):
   img = np.array(Draw.MolToImage(Chem.MolFromSmiles(converter.decode(ds)), size=(400,400)).convert("L", dither=None).convert("RGB"))
   img = np.where(img<253, 0, 1) * img
   return img

def smiles_to_img(smiles):
  return deepsmiles_to_img(converter.encode(smiles))

In [2]:
import pandas
csv = pandas.read_csv("80k.csv")

In [3]:
# prompt: convert all the data in canonicalsmiles column using converter.encode function and store them in a new array

smiles_arr = []
for smiles in csv['canonicalsmiles']:
  smiles_arr.append(converter.encode(smiles))


In [4]:
import torch
import torchvision

In [47]:
image_encoder = torchvision.models.efficientnet_v2_l(weights="DEFAULT")

In [10]:
smiles_encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-large-uncased')

Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_main


In [11]:
chars = list(set("".join(smiles_arr)))

In [12]:
tokens = dict.fromkeys(chars)
for i, char in enumerate(chars):
  tokens[char] = i

In [13]:
reversed_mapping = {}
for i, char in enumerate(chars):
  reversed_mapping[i] = char

In [14]:
for i, smiles in enumerate(smiles_arr):
  smiles_arr[i] = [len(chars)]+[tokens[char] for char in smiles]

In [None]:
image_encoder = image_encoder.features

In [54]:
image_encoder.classifier = torch.nn.Linear(1280, 1024)

In [None]:
import torchinfo
torchinfo.summary(image_encoder, (1, 3, 224, 224))

In [53]:
import torchinfo
smiles_encoder(torch.tensor([[1,2,3]])).last_hidden_state[:,0,:].flatten()

torch.Size([1, 1024])

In [95]:
CEL = torch.nn.CrossEntropyLoss()
class CL(torch.nn.Module):
  def __init__(self, image_encoder, smiles_encoder):
    super().__init__()
    self.image_encoder = image_encoder
    self.smiles_encoder = smiles_encoder
    self.smiles_proj = torch.nn.Linear(1024, 1024)
    # https://discuss.pytorch.org/t/how-could-i-create-a-module-with-learnable-parameters/28115
    self.t = torch.nn.Parameter(torch.randn(1))
    self.t.requires_grad = True
  def forward(self, image, smiles):
    image_embedding = self.image_encoder(image)
    image_embedding = torch.flatten(image_embedding, start_dim=2, end_dim=3)
    smiles_embedding = self.smiles_encoder(smiles).last_hidden_state[:,0,:]
    smiles_embedding = self.smiles_proj(smiles_embedding)
    n = smiles_embedding.shape()[0]
    a = torch.broadcast_to(image_embedding, (n, n, -1))
    b = torch.broadcast_to(smiles_embedding, (n, n, -1)).permute(1,0,2)
    logits = torch.nn.functional.cosine_similarity(a, b, dim=-1) * torch.exp(self.t)
    return logits

In [113]:
reversed_mapping[29] = ""
reversed_mapping[30] = ""

In [None]:
model = CL(image_encoder, smiles_encoder)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99987)

for EPOCH in range(30):
  n = 5
  for i in range(0, len(smiles_arr), n):
    images_list = []
    smiles_list = []
    maxlen = 0
    for j in range(n):
        image = deepsmiles_to_img("".join([reversed_mapping[_] for _ in smiles_arr[i+j]]))
        image = torch.tensor(image).permute(2,0,1)
        images_list.append(image)
        smiles_list.append(smiles_arr[i+j])
        maxlen = max(maxlen, len(smiles_arr[i+j]))
    for j in range(n):
      smiles_list[j] += [30] * (maxlen - len(smiles_list[j]))

    images_list = torch.stack(images_list).to(torch.float32)
    smiles_list = torch.tensor(smiles_list, dtype=torch.int16)
    print(images_list.shape)
    print(smiles_list.shape)

    optimizer.zero_grad()
    logits = model(images_list, smiles_list)

    labels = torch.arange(n)
    loss_i = CEL(logits, labels)
    loss_t = CEL(logits.T, labels)
    loss = (loss_i + loss_t)/2
    loss.backward()
    optimizer.step()

torch.Size([5, 3, 400, 400])
torch.Size([5, 24])


In [93]:
torch.tensor(_).unsqueeze(0).shape

torch.Size([1, 400, 400, 3])