In [1]:
import gentrl
import torch
import pandas as pd
torch.cuda.set_device(0)



In [2]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol


def get_num_rings_6(mol):
    r = mol.GetRingInfo()
    return len([x for x in r.AtomRings() if len(x) > 6])


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
    if masked and not mol_passes_filters(mol):
        return default
    return reward

In [None]:
! wget https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv

In [3]:
! ls -al

total 427072
drwxr-xr-x 1 root root     4096  2월 10 18:28 .
drwxr-xr-x 1 root root     4096  2월  5 19:06 ..
drwxr-xr-x 2 root root     4096  2월 10 12:16 .ipynb_checkpoints
-rw-r--r-- 1 root root     2574  2월  7 20:32 Untitled.ipynb
-rw-r--r-- 1 root root      885  2월  9 23:18 Untitled1.ipynb
-rw-r--r-- 1 root root 84482588  2월  7 17:02 dataset_v1.csv
-rw-r--r-- 1 root root 84482588  2월 10 09:57 dataset_v1.csv.1
-rw-r--r-- 1 root root 84482588  2월 10 12:16 dataset_v1.csv.2
-rw-r--r-- 1 root root 84482588  2월 10 18:26 dataset_v1.csv.3
-rw-r--r-- 1 root root    16292  2월 10 18:28 pretrain-Copy1.ipynb
-rwxr-xr-x 1 root root     9215  2월 10 18:25 pretrain.ipynb
-rwxr-xr-x 1 root root   312729  2월 10 18:24 sampling.ipynb
drwxr-xr-x 2 root root     4096  2월  7 17:03 saved_gentrl
drwxr-xr-x 2 root root     4096  2월  7 20:00 saved_gentrl_after_rl
-rw-r--r-- 1 root root 98895256  2월 10 11:16 train_plogp_plogpm.csv
-rw-r--r-- 1 root root    50919  2월 10 16:51 train_plogp_plogpm_te

In [13]:
df = pd.read_csv('dataset_v1.csv')
tempIndex = len(df)//100
df = df[0:tempIndex]
df = df[df['SPLIT'] == 'train']
df['plogP'] = df['SMILES'].apply(penalized_logP)
df.to_csv('train_plogp_plogpm.csv', index=None)
df.head()

Unnamed: 0,SMILES,SPLIT,plogP
0,CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1,train,-2.131918
1,CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1,train,0.792973
3,Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO,train,-0.366775
4,Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C,train,-5.202261
5,CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O,train,-2.24865


In [3]:
import torch
from torch import nn
from gentrl.tokenizer import encode, get_vocab_size

class CNNEncoder(nn.Module):
    def __init__(self, hidden_size=256, latent_size=50):
        super(CNNEncoder, self).__init__()

        self.embs = nn.Embedding(get_vocab_size(), hidden_size)
        self.cnn = nn.Conv1d(50,50,1)

        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(),
            nn.Linear(hidden_size, 2 * latent_size))

    def encode(self, sm_list):

        tokens, lens = encode(sm_list)
        to_feed = tokens.transpose(1, 0).to(self.embs.weight.device)

        outputs = self.cnn(self.embs(to_feed))
        outputs = self.cnn(outputs)

        outputs = outputs[lens, torch.arange(len(lens))]

        return self.final_mlp(outputs)


In [4]:
enc = CNNEncoder(latent_size=50)
dec = gentrl.DilConvDecoder(latent_input_size=50)
model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)
model.cuda();

In [5]:
md = gentrl.MolecularDataset(sources=[
    {'path':'train_plogp_plogpm.csv',
     'smiles': 'SMILES',
     'prob': 1,
     'plogP' : 'plogP',
    }], 
    props=['plogP'])

from torch.utils.data import DataLoader
train_loader = DataLoader(md, batch_size=50, shuffle=True, num_workers=1, drop_last=True)

In [None]:
# model.train_as_vaelp(train_loader, lr=1e-4)

In [6]:
model.train_as_vaelp(train_loader, 3, verbose_step=5, lr=1e-3)

Epoch 0 :
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 2.847;rec: -2.737;kl: -61.7;log_p_y_by_z: -1.718;log_p_z_by_y: -78.55;
!!!!!loss: 2.157;rec: -2.065;kl: -63.31;log_p_y_by_z: -1.548;log_p_z_by_y: -76.11;
!!!!!loss: 1.908;rec: -1.817;kl: -63.19;log_p_y_by_z: -1.54;log_p_z_by_y: -74.45;
!!!!!loss: 1.724;rec: -1.63;kl: -62.88;log_p_y_by_z: -1.56;log_p_z_by_y: -73.14;
!!!!!loss: 1.617;rec: -1.532;kl: -62.85;log_p_y_by_z: -1.48;log_p_z_by_y: -71.93;
!!!!!loss: 1.522;rec: -1.43;kl: -63.22;log_p_y_by_z: -1.553;log_p_z_by_y: -70.64;
!!!!!loss: 1.456;rec: -1.374;kl: -62.67;log_p_y_by_z: -1.449;log_p_z_by_y: -70.43;
!!!!!loss: 1.37;rec: -1.285;kl: -63.21;log_p_y_by_z: -1.479;log_p_z_by_y: -69.48;
!!!!!loss: 1.377;rec: -1.285;kl: -62.97;log_p_y_by_z: -1.555;log_p_z_by_y: -69.52;
!!!!!loss: 1.271;rec: -1.189;kl: -63.39;log_p_y_by_z: -1.457;log_p_z_by_y: -69.19;
!!!!!loss: 1.285;rec: -1.19;kl: -63.58;log_p_y_by_z: -1.58;log_p_z_

<gentrl.gentrl.TrainStats at 0x7f5ca1f4be48>

In [8]:
! mkdir -p saved_gentrlCNN

In [9]:
model.save('./saved_gentrlCNN/')

In [None]:
# ! ls -la ./saved_gentrl/