<a href="https://colab.research.google.com/github/weathon/3d2smile/blob/main/ITM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install rdkit deepsmiles
!pip3 install torchinfo
!pip install tqdm boto3 requests regex sentencepiece sacremoses huggingface_hub
!wget http://file.weasoft.com/80k.csv
!pip install transformers -U

Collecting rdkit
  Downloading rdkit-2023.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.5/30.5 MB[0m [31m48.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting deepsmiles
  Downloading deepsmiles-1.0.1-py2.py3-none-any.whl (12 kB)
Installing collected packages: deepsmiles, rdkit
Successfully installed deepsmiles-1.0.1 rdkit-2023.9.2
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0
Collecting boto3
  Downloading boto3-1.34.0-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.7 MB/s[0

In [4]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
import deepsmiles
import numpy as np
import pylab
converter = deepsmiles.Converter(rings=True, branches=True)
def deepsmiles_to_img(ds):
   img = np.array(Draw.MolToImage(Chem.MolFromSmiles(converter.decode(ds)), size=(400,400)).convert("L", dither=None).convert("RGB"))
   img = np.where(img<253, 0, 1) * img
   return img

def smiles_to_img(smiles):
  return deepsmiles_to_img(converter.encode(smiles))
import torch
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

In [5]:
import pandas
csv = pandas.read_csv("80k.csv")

smiles_arr = []
for smiles in csv['canonicalsmiles']:
  smiles_arr.append(converter.encode(smiles))

In [6]:
smiles_encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')

Downloading: "https://github.com/huggingface/pytorch-transformers/zipball/main" to /root/.cache/torch/hub/main.zip


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [7]:
import torchvision
image_encoder = torchvision.models.swin_s(weights='DEFAULT')

Downloading: "https://download.pytorch.org/models/swin_s-5e29d889.pth" to /root/.cache/torch/hub/checkpoints/swin_s-5e29d889.pth
100%|██████████| 190M/190M [00:01<00:00, 137MB/s]


In [8]:
import torchinfo
image_encoder.norm = torch.nn.Identity()
image_encoder.permute = torch.nn.Identity()
image_encoder.avgpool = torch.nn.Identity()
image_encoder.flatten = torch.nn.Flatten(-3, -2)

image_encoder.head = torch.nn.Identity()
# torchinfo.summary(image_encoder, input_size=(1, 3, 400, 400))

In [9]:
image_encoder(torch.rand(1, 3, 400, 400)).shape

torch.Size([1, 169, 768])

In [10]:
smiles_encoder.pooler = torch.nn.Identity()
# torchinfo.summary(smiles_encoder, input_data=torch.ones(1, 128, dtype=torch.int32))

In [11]:
smiles_encoder(torch.zeros(1,128, dtype=torch.int32)).last_hidden_state.shape

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


torch.Size([1, 128, 768])

In [12]:
chars = list(set("".join(smiles_arr)))
tokens = dict.fromkeys(chars)
for i, char in enumerate(chars):
  tokens[char] = i
reversed_mapping = {}
for i, char in enumerate(chars):
  reversed_mapping[i] = char
for i, smiles in enumerate(smiles_arr):
  smiles_arr[i] = [tokens[char] for char in smiles]

In [36]:
# https://youtu.be/ug8YvZOjOCE?t=2692
class CL(torch.nn.Module):
  def __init__(self, maxlen):
    super().__init__()
    self.image_encoder = image_encoder
    self.smiles_encoder = smiles_encoder
    self.smiles_proj = torch.nn.Linear(768, 512)
    self.pos1 = torch.nn.Embedding(13*13, 768)
    self.pos2 = torch.nn.Embedding(maxlen, 768)
    self.modal = torch.nn.Embedding(2, 768)
    self.i_begin = torch.nn.Embedding(1, 768)
    self.t_begin = torch.nn.Embedding(1, 768)

  def forward(self, image, smiles):
    # print(self.i_begin(torch.tensor(0)).unsqueeze(0).unsqueeze(0).repeat(image.shape[0],1,1).shape)
    image_embedding = self.image_encoder(image)
    smiles_embedding = self.smiles_encoder(smiles, attention_mask=(smiles!=30)).last_hidden_state
    pos_image = self.pos1(torch.arange(13*13))
    m_i = self.modal(torch.zeros(image_embedding.shape[1], dtype=torch.int32))
    image_embedding = image_embedding + pos_image + m_i
    pos_txt = self.pos2(torch.arange(smiles_embedding.shape[1]))
    m_t = self.modal(torch.ones(smiles_embedding.shape[1], dtype=torch.int32))
    smiles_embedding = smiles_embedding + pos_txt + m_t
    seq = torch.cat([self.i_begin(torch.tensor(0)).unsqueeze(0).unsqueeze(0).repeat(image.shape[0],1,1), image_embedding, self.i_begin(torch.tensor(0)).unsqueeze(0).unsqueeze(0).repeat(image.shape[0],1,1), smiles_embedding], dim=1)
    print(seq.shape)
CL(512)(torch.rand(2, 3, 400, 400), torch.zeros(2,128, dtype=torch.int32))

torch.Size([2, 1, 768])
torch.Size([2, 299, 768])
