In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import math
import copy
import pickle

from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold
from torch.nn.utils.rnn import pad_sequence

import matplotlib.pyplot as plt

# Report the PyTorch version and whether a CUDA device is visible (one value per line).
print(torch.__version__, torch.cuda.is_available(), sep='\n')

# Read the symbol dictionary
The symbol dictionary is saved as 'OdorCode-40 Symbol Dictionary' (in the `CHEMBL/` directory) by running 'datasetMake_pretrain.ipynb'.

In [None]:
LIMIT_SMILES_LENGTH = 100  # NOTE(review): not referenced by the cells visible in this notebook

# Load the symbol dictionary produced by 'datasetMake_pretrain.ipynb':
#   symbol_ID : mapping from a SMILES symbol string to its integer ID
#   ID_symbol : sequence of symbol strings indexed by ID
#   sID       : presumably the number of symbols; used below as the vocabulary size
# NOTE: pickle.load can execute arbitrary code -- only load files you generated yourself.
# The `with` block guarantees the file is closed even if unpickling fails.
with open('CHEMBL/OdorCode-40 Symbol Dictionary', 'rb') as f:
  [symbol_ID, ID_symbol, sID] = pickle.load(f)

# Special token IDs
PAD_ID = 0  # padding (embedding padding_idx; masked out of attention and of the loss)
CLS_ID = 1  # prepended to every sequence fed to the encoders
BOS_ID = 2  # start of each SMILES
EOS_ID = 3  # end of each SMILES
MSK_ID = 4  # replaces symbols chosen by masking()

# Helper functions
(1) `smiles_str2smiles`: convert a SMILES string into a list of symbol IDs

(2) `smiles2smiles_str`: convert a list of symbol IDs back into a SMILES string

(3) `masking`: randomly mask the symbols of the SMILES that is fed to the second encoder

In [None]:
#----------------------------------#
#           smiles_str2smiles      #
#----------------------------------#
# Convert a SMILES string into the list of symbol IDs (function below).

# Length of the longest symbol in the dictionary (e.g. two-character symbols such
# as 'Na'); it bounds the greedy longest-match search in smiles_str2smiles.
max_length_symbol = max(map(len, ID_symbol))

def smiles_str2smiles(smiles_str, flag=False):
  """Convert a SMILES string into a list of symbol IDs.

  The string is tokenized greedily from left to right. At each position the
  longest substring (at most ``max_length_symbol`` characters) that appears in
  ``symbol_ID`` is taken as one symbol. Multi-character element symbols such as
  'Na' therefore map to a single ID.

  Args:
    smiles_str: SMILES string to tokenize.
    flag: unused; kept for call compatibility.

  Returns:
    List of symbol IDs. If a position matches no known symbol, tokenization
    stops there and only the IDs collected so far are returned.
  """
  ids = []
  pos = 0
  total = len(smiles_str)
  while pos < total:
    for width in range(max_length_symbol, 0, -1):
      piece = smiles_str[pos:pos + width]
      if pos + width <= total and piece in symbol_ID:
        ids.append(symbol_ID[piece])
        pos += width
        break
    else:
      # no dictionary symbol starts here: stop and return the prefix
      break
  return ids

#----------------------------------#
#           smiles2smiles_str      #
#----------------------------------#
def smiles2smiles_str(smiles):
  """Convert a sequence of symbol IDs back into a SMILES string.

  Each ID is looked up in ``ID_symbol`` and the symbols are concatenated in order.
  """
  return ''.join(ID_symbol[symbol_id] for symbol_id in smiles)


#----------------------------------#
#             masking              #
#----------------------------------#
# Default masking probability. NOTE: later cells reassign the global MaskRate
# (e.g. the hyper-parameter cell); masking() reads it at call time.
MaskRate = 0.1
def masking(smiles, mask_rate=None):
  """Randomly replace symbols with MSK_ID.

  Args:
    smiles: sequence of symbol IDs (not modified).
    mask_rate: probability of masking each symbol. Defaults to the global
      ``MaskRate`` as it is when the function is called.

  Returns:
    A new list of the same length as ``smiles``. Each symbol is independently
    replaced by MSK_ID with probability ``mask_rate``. Exactly one
    ``random.random()`` draw is consumed per symbol, in order.
  """
  rate = MaskRate if mask_rate is None else mask_rate
  # Build the result in a single pass; repeated `list + [x]` copies the list
  # on every step and is quadratic in the sequence length.
  return [MSK_ID if random.random() < rate else s for s in smiles]

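# Read data for pre-training
Read pairs of SMILES from the file 'CHEMBL/OdorCode-40 Pretrain MLM_data', which is produced by running 'datasetMake_pretrain.ipynb'. Each pair consists of a canonical SMILES, which is the input to the first encoder, and a target SMILES. The target is also masked to form the input to the second encoder.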

In [None]:
import pickle
# NOTE: pickle.load can execute arbitrary code -- only load files produced by
# 'datasetMake_pretrain.ipynb' (or another trusted source).
# The `with` block guarantees the file is closed even if unpickling fails.
with open('CHEMBL/OdorCode-40 Pretrain MLM_data', 'rb') as f:
  [canonical_smiles_list, smiles_list] = pickle.load(f)

# Entry i of canonical_smiles_list is paired with entry i of smiles_list.
assert len(canonical_smiles_list) == len(smiles_list), 'input/target SMILES lists differ in length'

print('Sample size : ', len(smiles_list))

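# Model
Two-encoder model for pre-training. The first encoder reads the canonical SMILES, and its output at the CLS position serves as a molecular embedding. This embedding replaces the CLS token of the masked SMILES fed to the second encoder, which is trained to recover the original symbols.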

In [None]:
NumToken = sID   # vocabulary size for the embedding and the output layer
InitRange = 0.1  # symbol embeddings are initialised uniformly in [-InitRange, InitRange]

MaskRate = 0.1   # NOTE: reassigned in the hyper-parameter cell before training

#--------------------------------------------------------------------------------
class PositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding.

    The table is precomputed once for ``max_len`` positions and stored as a
    buffer named ``pe`` with shape (1, max_len, d_model). Being a buffer, it
    follows the module across ``.to(device)``, appears in ``state_dict`` and is
    never trained. Even columns hold sin terms and odd columns hold cos terms.

    Args:
        d_model: embedding dimension; both even and odd values are supported.
        max_len: longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=2048):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column; trim
        # div_term so the cos half has a matching width (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered buffers receive no gradients, so no requires_grad flag is needed.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the length of dimension 1 of ``x`` is used, so either token IDs or
        embeddings may be passed. The result has shape (1, x.size(1), d_model)
        and broadcasts over the batch dimension.
        """
        return self.pe[:, :x.size(1)]


#----------------------------------------------------------------------------
class SymbolEncoder(nn.Module):
    """Embed symbol IDs and scale the embeddings by sqrt(d_model).

    Weights are initialised uniformly in [-InitRange, InitRange], reading the
    global ``InitRange`` at construction time. The padding row (``PAD_ID``) is
    then reset to zeros. nn.Embedding zero-initialises that row, but the uniform
    re-init overwrites it. Since the padding row never receives gradients, it
    would otherwise keep a random value for the whole run.
    """
    def __init__(self, num_token, d_model):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(num_token, d_model, padding_idx=PAD_ID)
        self.embed.weight.data.uniform_(-InitRange, InitRange)  # embedding init
        if self.embed.padding_idx is not None:
            self.embed.weight.data[self.embed.padding_idx].zero_()  # keep padding embedding at zero

    def forward(self, src):
        """Map an integer tensor of symbol IDs to embeddings (trailing dim d_model) times sqrt(d_model)."""
        return self.embed(src) * math.sqrt(self.d_model)
#----------------------------------------------------------------------------
class MyTransformerEncoder(nn.Module):
    """A stack of ``NumLayers`` TransformerEncoderLayer blocks followed by a final LayerNorm.

    Layer options (``NormFirst``, ``Activation``, ``Dropout``, ``NumLayers``) are
    read from the notebook's global hyper-parameters at construction time.
    Inputs are batch-first. The attribute name ``transformer_encoder`` is part
    of the saved state_dict keys.
    """
    def __init__(self, d_model, num_head, d_hidden):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model,
            num_head,
            dim_feedforward=d_hidden,
            dropout=Dropout,
            activation=Activation,
            norm_first=NormFirst,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, NumLayers, norm=nn.LayerNorm(d_model))

    def forward(self, x, padding_mask):
        """Encode ``x``; key positions where ``padding_mask`` is True are not attended to."""
        return self.transformer_encoder(x, src_key_padding_mask=padding_mask)

#-------------- two-encoder MLM model ---------------------------------------------
class MLM(nn.Module):
    """Two-encoder masked-language model used for pre-training.

    Encoder 1 (``smiles_encoder1``) reads the canonical SMILES, and its output at
    the CLS position is used as a molecular embedding. When ``flag`` is True,
    that embedding replaces the CLS token embedding of the masked SMILES fed to
    encoder 2 (``smiles_encoder2``). Encoder 2's outputs are projected to
    per-symbol logits.

    Hyper-parameters (Dropout, DimEmbed, NumToken, NumHead, DimTfHidden) are read
    from notebook globals at construction time. The attribute names appear in
    saved state_dicts (smiles_encoder1 / symbol_encoder are saved separately),
    so do not rename them.
    """
    def __init__(self):
        super().__init__()

        self.drop1 = nn.Dropout(p=Dropout)
        self.drop2 = nn.Dropout(p=Dropout)
        self.positional_encoder = PositionalEncoder(DimEmbed)
        self.symbol_encoder = SymbolEncoder(NumToken, DimEmbed)  # shared by both encoders
        self.smiles_encoder1 = MyTransformerEncoder(DimEmbed, NumHead, DimTfHidden)  # canonical SMILES
        self.smiles_encoder2 = MyTransformerEncoder(DimEmbed, NumHead, DimTfHidden)  # masked SMILES
        self.fnn = nn.Linear(DimEmbed, NumToken)  # per-position logits over the vocabulary

    def forward(self, canonical_smiles, smiles_masked, flag):
        """Predict the original symbols of ``smiles_masked``.

        Args:
            canonical_smiles: integer tensor (batch, len1) of symbol IDs padded with PAD_ID.
            smiles_masked: integer tensor (batch, len2) of symbol IDs, some replaced by MSK_ID.
            flag: if True, inject the molecular embedding from encoder 1 at the
                CLS position of encoder 2's input; if False, encoder 1 is unused.

        Returns:
            Logits of shape (batch, len2, NumToken), one row per symbol of
            ``smiles_masked`` (the CLS position is dropped).
        """
        # CLS column, created on the inputs' device rather than the global `device`.
        cls = torch.full((canonical_smiles.size(0), 1), CLS_ID,
                         dtype=torch.long, device=canonical_smiles.device)

        embed = None
        if flag:
            # prepend CLS, embed + positional encoding (with dropout), run encoder 1
            x1 = torch.cat((cls, canonical_smiles), dim=1)
            padding_mask1 = (x1 == PAD_ID)
            x1 = self.drop1(self.symbol_encoder(x1) + self.positional_encoder(x1))
            x1 = self.smiles_encoder1(x1, padding_mask1)
            embed = x1[:, 0, :]  # molecular embedding = output at the CLS position

        # prepend CLS to the masked SMILES and embed it
        x2 = torch.cat((cls, smiles_masked), dim=1)
        padding_mask2 = (x2 == PAD_ID)
        x2 = self.symbol_encoder(x2)
        if embed is not None:
            x2[:, 0, :] = embed  # replace the CLS token embedding with the molecular embedding
        x2 = self.drop2(x2 + self.positional_encoder(x2))
        y = self.smiles_encoder2(x2, padding_mask2)
        return self.fnn(y[:, 1:, :])


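# Function for running the experiment
The first encoder (`smiles_encoder1`) and the symbol encoder are saved to file every 50 epochs (under `modelsave/`) and again after the last epoch (in the working directory).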

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Token-level cross-entropy averaged over positions; targets equal to 0 (PAD_ID) are ignored.
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)

def _pad_symbol_batch(id_lists):
  """Wrap each ID list as [BOS_ID] + ids + [EOS_ID] and pad to a (batch, max_len) tensor on `device`.

  pad_sequence pads with 0, which equals PAD_ID.
  """
  return pad_sequence([torch.tensor([BOS_ID] + ids + [EOS_ID]) for ids in id_lists],
                      batch_first=True).to(device)


def _make_batch(idxs):
  """Build the (canonical, target, masked) tensors for the sample indices `idxs`.

  The masked SMILES is built last, calling masking() once per index in index order.
  """
  canonical_smiles = _pad_symbol_batch([canonical_smiles_list[idx] for idx in idxs])
  target_smiles = _pad_symbol_batch([smiles_list[idx] for idx in idxs])
  masked_smiles = _pad_symbol_batch([masking(smiles_list[idx]) for idx in idxs])
  return canonical_smiles, target_smiles, masked_smiles


def _token_loss(estimated_smiles_prob, target_smiles):
  """Cross-entropy between logits (batch, len, num_tokens) and targets (batch, len).

  Targets equal to PAD_ID are ignored by `criterion`.
  """
  if estimated_smiles_prob.shape[:2] != target_smiles.shape:
    # Raise instead of calling exit(), which would stop the notebook kernel.
    raise RuntimeError('prediction shape %s does not match target shape %s'
                       % (tuple(estimated_smiles_prob.shape), tuple(target_smiles.shape)))
  num_tokens = estimated_smiles_prob.size(-1)
  return criterion(estimated_smiles_prob.reshape(-1, num_tokens), target_smiles.reshape(-1))


def _recovery_counts(estimated_smiles_prob, target_smiles, masked_smiles):
  """Count correctly recovered symbols at masked positions and at unmasked non-pad positions.

  Returns (num_mask, num_success_mask, num_notmp, num_success_notmp) as 0-d tensors.
  """
  estimated_smiles = torch.argmax(estimated_smiles_prob, dim=2)
  correct = (target_smiles == estimated_smiles)
  is_mask = (masked_smiles == MSK_ID)
  is_notmp = (masked_smiles != MSK_ID) & (masked_smiles != PAD_ID)
  return is_mask.sum(), (correct & is_mask).sum(), is_notmp.sum(), (correct & is_notmp).sum()


def _percent(num, den):
  """Return 100 * num / den as a float, or nan when den is 0."""
  den = float(den)
  return 100.0 * float(num) / den if den else float('nan')


def _run_epoch(model, dataloader, flag, optimizer=None):
  """Run one pass over `dataloader`; train when `optimizer` is given, otherwise evaluate.

  Returns (mean_loss, acc_mask, acc_notmp). mean_loss is the per-batch mean loss
  weighted by batch size. acc_mask and acc_notmp are the percentages of correctly
  recovered symbols at masked positions and at unmasked non-pad positions.
  """
  training = optimizer is not None
  model.train(training)

  sample_num = 0
  total_loss = 0.0
  num_mask = num_success_mask = num_notmp = num_success_notmp = 0

  # No autograd graph is needed (or kept) during evaluation.
  with torch.set_grad_enabled(training):
    for idxs in dataloader:
      canonical_smiles, target_smiles, masked_smiles = _make_batch(idxs)

      if training:
        optimizer.zero_grad()
      estimated_smiles_prob = model(canonical_smiles, masked_smiles, flag)  # (batch, len_seq, num_tokens)
      loss = _token_loss(estimated_smiles_prob, target_smiles)
      if training:
        loss.backward()
        optimizer.step()

      sample_num += len(idxs)
      total_loss += loss.item() * len(idxs)

      n_mask, n_ok_mask, n_notmp, n_ok_notmp = _recovery_counts(
          estimated_smiles_prob.detach(), target_smiles, masked_smiles)
      num_mask += n_mask
      num_success_mask += n_ok_mask
      num_notmp += n_notmp
      num_success_notmp += n_ok_notmp

  torch.cuda.empty_cache()
  mean_loss = total_loss / sample_num if sample_num else float('nan')
  return mean_loss, _percent(num_success_mask, num_mask), _percent(num_success_notmp, num_notmp)


def main_program(flag=True):
  """Pre-train the two-encoder MLM and save the first encoder and the symbol encoder.

  Hyper-parameters are read from notebook globals (TotalSize, batch_size, LearningRate,
  NumEpoch, DimEmbed, DimTfHidden, NumHead, NumLayers, MaskRate), as is the data
  (canonical_smiles_list, smiles_list). The first TotalSize samples are randomly
  split 4:1 into training and test sets.

  Side effects:
    * sets the globals list_train_loss / list_train_acc / list_test_loss /
      list_test_acc (per-epoch loss and masked-symbol accuracy in %);
    * every 50 epochs saves smiles_encoder1 and symbol_encoder state_dicts under
      'modelsave/' (created if missing), and after the last epoch saves them
      in the working directory;
    * prints, per epoch: epoch, loss, masked acc, unmasked acc for train, then for test;
    * draws a 2x2 figure of the four curves.

  Args:
    flag: passed to MLM.forward; if True, the molecular embedding from the first
      encoder replaces the CLS token of the second encoder's input.

  Raises:
    ValueError: if TotalSize exceeds the number of loaded samples.
  """
  global list_train_loss, list_train_acc, list_test_loss, list_test_acc

  if TotalSize > len(smiles_list):
    raise ValueError('TotalSize (%d) exceeds the number of samples (%d)' % (TotalSize, len(smiles_list)))

  TrainSize = int(TotalSize*4/5)
  TestSize = TotalSize-TrainSize

  train_idx_list, test_idx_list = torch.utils.data.random_split(list(range(TotalSize)), [TrainSize, TestSize])
  train_dataloader = DataLoader(train_idx_list, batch_size, shuffle=True)
  test_dataloader = DataLoader(test_idx_list, batch_size, shuffle=True)

  #====================== model init, set optimizer ==============
  model = MLM().to(device)
  optimizer = optim.Adam(model.parameters(), lr=LearningRate)

  list_train_loss, list_train_acc, list_test_loss, list_test_acc = [], [], [], []

  hparams = 'D'+str(DimEmbed)+'.Hidden'+str(DimTfHidden)+'.Head'+str(NumHead)+'.L'+str(NumLayers)+'.R'+str(MaskRate)+'.S'+str(TotalSize)
  pf_symbol_en = 'OdorCode-40 symbol_encoder ' + hparams
  pf_smiles_en = 'OdorCode-40 smiles_encoder ' + hparams

  # Periodic checkpoints go to 'modelsave/'; make sure it exists before the first save.
  os.makedirs('modelsave', exist_ok=True)

  for epoch in range(1, NumEpoch+1):
    #---------------- train step -----------------
    mean_loss, acc_mask, acc_notmp = _run_epoch(model, train_dataloader, flag, optimizer)
    print('%4d'%epoch, '  %6.4f'%mean_loss,  '  %6.4f'%acc_mask, '  %6.4f'%acc_notmp, end='  ')
    list_train_loss.append(mean_loss)
    list_train_acc.append(acc_mask)

    if epoch % 50 == 0:
      torch.save(model.smiles_encoder1.state_dict(), 'modelsave/'+pf_smiles_en+'-epoch.'+str(epoch))
      torch.save(model.symbol_encoder.state_dict(),  'modelsave/'+pf_symbol_en+'-epoch.'+str(epoch))

    #--------------- test step ------------------
    mean_loss, acc_mask, acc_notmp = _run_epoch(model, test_dataloader, flag)
    print('%4d'%epoch, '  %6.4f'%mean_loss,  '  %6.4f'%acc_mask, '  %6.4f'%acc_notmp)
    list_test_loss.append(mean_loss)
    list_test_acc.append(acc_mask)

  # save model's parameters
  torch.save(model.smiles_encoder1.state_dict(), pf_smiles_en+'-epoch.'+str(NumEpoch))
  torch.save(model.symbol_encoder.state_dict(),  pf_symbol_en+'-epoch.'+str(NumEpoch))

  fig, axes = plt.subplots(2, 2, figsize=(10, 10))
  curves = [(list_train_loss, 'train loss'), (list_train_acc, 'train accuracy on masked symbols (%)'),
            (list_test_loss, 'test loss'), (list_test_acc, 'test accuracy on masked symbols (%)')]
  for ax, (values, title) in zip(axes.flat, curves):
    ax.plot(range(1, len(values) + 1), values)
    ax.set(title=title, xlabel='epoch')
  plt.show()

  del model, optimizer
  torch.cuda.empty_cache()

# Run the experiment

In [None]:
torch.cuda.empty_cache()

# ---- model hyper-parameters ----
DimEmbed = 256     # dimension of the symbol embeddings
DimTfHidden = 256  # dimension of the feed-forward layer inside each TransformerEncoderLayer
NumHead = 16       # number of attention heads in each TransformerEncoderLayer
NumLayers = 10     # number of TransformerEncoderLayer blocks in each of the two encoders
NormFirst = True   # apply LayerNorm before attention / feed-forward (pre-LN)
Activation = 'gelu'
MaskRate = 0.5     # probability that masking() replaces a symbol with MSK_ID

# ---- training parameters ----
LearningRate = 0.0005
NumEpoch = 800
Dropout = 0.1
batch_size = 128
InitRange = 0.1    # symbol embeddings are initialised uniformly in [-InitRange, InitRange]

TotalSize = 100000  # number of samples used (split 4:1 into train/test)

# Seed every RNG the run uses (Python `random` for masking; torch for the
# train/test split, shuffling, weight init and dropout) so runs are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

main_program(True)