In [1]:
import gentrl
import torch
# Make GPU 0 the current CUDA device for the rest of the notebook.
torch.cuda.set_device(0)



In [2]:
import torch
from torch import nn
from gentrl.tokenizer import encode, get_vocab_size

class CNNEncoder(nn.Module):
    """Encode SMILES strings into a vector of size ``2 * latent_size``.

    Pipeline: token embedding -> the same 1x1 ``Conv1d`` applied twice ->
    pick one position per sample (the index returned in ``lens``) -> MLP.

    The convolution treats *token positions* as channels and is applied to
    each sample independently.  Earlier revisions fed the convolution a
    (positions, batch, hidden) tensor, which made the batch axis the channel
    axis: the encoder only ran when the batch held exactly 50 samples and the
    result for one molecule depended on the other molecules in the batch.
    Parameter names and shapes are unchanged, so existing checkpoints still
    load, but encoder weights trained with the old layout should be
    re-trained or at least re-validated.

    Args:
        hidden_size: Embedding width and MLP hidden width.
        latent_size: The MLP emits ``2 * latent_size`` values per sample
            (presumably the mean and log-variance of each latent -- confirm
            against how gentrl consumes the encoder output).
        max_len: Number of token positions per sequence; it must equal the
            padded length produced by ``gentrl.tokenizer.encode``
            (assumed to be 50 -- TODO confirm).
    """

    def __init__(self, hidden_size=256, latent_size=50, max_len=50):
        super(CNNEncoder, self).__init__()

        self.embs = nn.Embedding(get_vocab_size(), hidden_size)
        # Kernel size 1: a learned mixing of token positions, shared across
        # every embedding feature.
        self.cnn = nn.Conv1d(max_len, max_len, 1)

        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(),
            nn.Linear(hidden_size, 2 * latent_size))

    def encode(self, sm_list):
        """Map a list of SMILES strings to a (len(sm_list), 2 * latent_size) tensor."""
        tokens, lens = encode(sm_list)
        # Keep the tokenizer's (batch, positions) layout so that Conv1d sees
        # the batch on axis 0 and token positions on its channel axis.
        tokens = tokens.to(self.embs.weight.device)

        # (batch, positions, hidden); the conv keeps this shape.
        outputs = self.cnn(self.embs(tokens))
        outputs = self.cnn(outputs)

        # For sample i keep the feature vector at position lens[i].
        sample_idx = torch.arange(len(lens), device=outputs.device)
        position_idx = torch.as_tensor(lens, dtype=torch.long,
                                       device=outputs.device)
        outputs = outputs[sample_idx, position_idx]

        return self.final_mlp(outputs)

In [3]:
# The encoder's latent_size and the decoder's latent_input_size use the same value.
latent_size = 50
# Third and fourth positional arguments of gentrl.GENTRL: lists of ('c', 20)
# descriptors (presumably one per latent dimension and one for the single
# property -- confirm against the gentrl.GENTRL signature).
latent_descr = [('c', 20)] * latent_size
feature_descr = [('c', 20)]

enc = CNNEncoder(latent_size=latent_size)
dec = gentrl.DilConvDecoder(latent_input_size=latent_size)
model = gentrl.GENTRL(enc, dec, latent_descr, feature_descr, beta=0.001)
# Trailing ';' suppresses the module repr in the cell output.
model.cuda();

In [4]:
# Display the encoder's module layout.
enc

CNNEncoder(
  (embs): Embedding(28, 256)
  (cnn): Conv1d(50, 50, kernel_size=(1,), stride=(1,))
  (final_mlp): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=256, out_features=100, bias=True)
  )
)

In [5]:
# Display the decoder's module layout.
dec

DilConvDecoder(
  (latent_fc): Linear(in_features=50, out_features=128, bias=True)
  (input_embeddings): Embedding(28, 128)
  (logits_1x1_layer): Conv1d(128, 28, kernel_size=(1,), stride=(1,))
  (parameters): ParameterList(
      (0): Parameter containing: [torch.cuda.FloatTensor of size 28x128 (GPU 0)]
      (1): Parameter containing: [torch.cuda.FloatTensor of size 28x128x1 (GPU 0)]
      (2): Parameter containing: [torch.cuda.FloatTensor of size 28 (GPU 0)]
      (3): Parameter containing: [torch.cuda.FloatTensor of size 128x50 (GPU 0)]
      (4): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]
      (5): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]
      (6): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]
      (7): Parameter containing: [torch.cuda.FloatTensor of size 128x128x1 (GPU 0)]
      (8): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]
      (9): Parameter containing: [torch.cuda.FloatTensor o

In [6]:
# Load the previously saved model from the `saved_gentrlCNN/` directory, then
# call .cuda() again so anything loaded ends up on the GPU.
model.load('saved_gentrlCNN/')
model.cuda();

In [7]:
# Display the full model (encoder + decoder) after loading.
model

GENTRL(
  (enc): CNNEncoder(
    (embs): Embedding(28, 256)
    (cnn): Conv1d(50, 50, kernel_size=(1,), stride=(1,))
    (final_mlp): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=256, out_features=100, bias=True)
    )
  )
  (dec): DilConvDecoder(
    (latent_fc): Linear(in_features=50, out_features=128, bias=True)
    (input_embeddings): Embedding(28, 128)
    (logits_1x1_layer): Conv1d(128, 28, kernel_size=(1,), stride=(1,))
    (parameters): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 28x128 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 28x128x1 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 28 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 128x50 (GPU 0)]
        (4): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]
        (5): Parameter 

In [8]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol

from moses.utils import disable_rdkit_log
# Turn off RDKit log output (as the helper's name indicates) so that scoring
# many generated strings does not flood the notebook.
disable_rdkit_log()

def get_num_rings_6(mol):
    """Return how many rings of ``mol`` contain more than six atoms.

    ``mol`` must expose ``GetRingInfo().AtomRings()``, an iterable holding
    one sequence of atoms per ring.
    """
    ring_atoms = mol.GetRingInfo().AtomRings()
    return sum(1 for ring in ring_atoms if len(ring) > 6)


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    """Score a molecule as ``logP - SA - (number of rings with > 6 atoms)``.

    Args:
        mol_or_smiles: Anything ``get_mol`` accepts.
        masked: When true, molecules rejected by ``mol_passes_filters``
            receive ``default`` instead of their score.
        default: Value returned when ``get_mol`` yields ``None`` or the
            molecule is masked out.

    Returns:
        The score, or ``default`` in the two cases above.
    """
    molecule = get_mol(mol_or_smiles)
    if molecule is None:
        return default

    # The score is computed before the filter check, as in the original.
    score = logP(molecule) - SA(molecule) - get_num_rings_6(molecule)
    rejected = masked and not mol_passes_filters(molecule)
    return default if rejected else score

In [9]:
# Reinforcement-learning fine-tuning with penalized_logP as the reward.
# The second positional argument is 10 -- presumably the iteration count
# (the run prints ten '!' marks); confirm against gentrl's train_as_rl.
model.train_as_rl(penalized_logP,10)


!!!!!!!!!!

<gentrl.gentrl.TrainStats at 0x7f7351792748>

In [10]:
# Create the output directory for the fine-tuned model (-p: no error if it exists).
! mkdir -p saved_gentrlCNN_after_rl

In [11]:
# Save the model after RL fine-tuning into the directory created above,
# leaving the original `saved_gentrlCNN/` checkpoint untouched.
model.save('./saved_gentrlCNN_after_rl/')