<a href="https://colab.research.google.com/github/yingzibu/drug_design_JAK/blob/main/VAE/Molecular_VAE_others.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the molecular-vae repository into the working directory; a later cell
# reads molecular-vae/data/processed.zip from this checkout.
!git clone https://github.com/aksub99/molecular-vae.git

Cloning into 'molecular-vae'...
remote: Enumerating objects: 188, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 188 (delta 0), reused 0 (delta 0), pack-reused 185[K
Receiving objects: 100% (188/188), 2.99 MiB | 19.00 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [2]:
from __future__ import print_function

import argparse
import gzip
import os

import h5py
import numpy as np
import pandas
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from sklearn import model_selection
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

In [3]:
def one_hot_array(i, n):
    """Return a length-``n`` list of ints holding 1 at index ``i`` and 0 elsewhere.

    Args:
        i: Position to mark with 1.
        n: Length of the returned list.

    Returns:
        A list of ``n`` ints. If ``i`` is not in ``range(n)`` every entry is 0.
    """
    # `xrange` only exists on Python 2, and Python 3's map() is lazy; a list
    # comprehension over range() yields a real, re-iterable list on both.
    return [int(ix == i) for ix in range(n)]

def one_hot_index(vec, charset):
    """Map every item of ``vec`` to its position in ``charset``.

    Args:
        vec: Iterable of items to look up.
        charset: Object exposing ``.index(item)`` (e.g. a list or str).

    Returns:
        A list with one index per item of ``vec``, in order.

    Raises:
        Whatever ``charset.index`` raises for an item it does not contain
        (``ValueError`` for built-in lists and strings).
    """
    # Materialise the result: on Python 3 a bare map() is a single-use iterator
    # that cannot be indexed or measured with len().
    return [charset.index(item) for item in vec]

def from_one_hot_array(vec):
    """Locate the first entry of ``vec`` that equals 1.

    Returns the first-axis index of that entry as a plain ``int``, or ``None``
    when no entry equals 1.
    """
    # np.where yields one index array per axis; only the first axis is used.
    hits = np.where(vec == 1)[0]
    return int(hits[0]) if hits.size else None

def decode_smiles_from_indexes(vec, charset):
    """Build a string by looking up each index of ``vec`` in ``charset``.

    The looked-up symbols are concatenated in order and surrounding whitespace
    is stripped from the result.
    """
    symbols = (charset[idx] for idx in vec)
    joined = "".join(symbols)
    return joined.strip()

def load_dataset(filename, split = True):
    """Read the datasets stored in an HDF5 file.

    Args:
        filename: Path of the HDF5 file to open read-only.
        split: When true, also read the ``data_train`` dataset.

    Returns:
        ``(data_train, data_test, charset)`` when ``split`` is true, otherwise
        ``(data_test, charset)``. Each element is the full content of the
        same-named dataset in the file (read with ``[:]``).

    Raises:
        Whatever ``h5py`` raises when the file cannot be opened or a dataset
        is missing; the file handle is closed in either case.
    """
    # The context manager closes the file even if one of the reads fails,
    # which the previous explicit close() did not guarantee.
    with h5py.File(filename, 'r') as h5f:
        data_train = h5f['data_train'][:] if split else None
        data_test = h5f['data_test'][:]
        charset = h5f['charset'][:]
    if split:
        return (data_train, data_test, charset)
    return (data_test, charset)


In [4]:
class MolecularVAE(nn.Module):
    """Variational autoencoder with a Conv1d encoder and a GRU decoder.

    ``forward`` takes a float tensor of shape ``(batch, max_len, charset_size)``.
    The Conv1d layers use ``max_len`` as the channel axis and convolve along the
    ``charset_size`` axis. The decoder emits a tensor of the same shape whose
    last axis is a softmax distribution.

    Args:
        max_len: Number of sequence positions (Conv1d input channels and number
            of GRU steps produced by the decoder). Defaults to 120.
        charset_size: Size of the last input axis and of the decoder's output
            distribution. Must exceed 26 so the three convolutions leave at
            least one position. Defaults to 33.

    Raises:
        ValueError: If ``charset_size`` is too small for the convolution stack.
    """

    def __init__(self, max_len=120, charset_size=33):
        super().__init__()

        # Each unpadded, stride-1 convolution shortens the axis by kernel_size - 1.
        conv_out_len = charset_size - (9 - 1) - (9 - 1) - (11 - 1)
        if conv_out_len < 1:
            raise ValueError(
                "charset_size must be at least 27 for the convolution stack, "
                "got {}".format(charset_size)
            )
        self.max_len = max_len
        self.charset_size = charset_size

        # Layers are created in the same order as before so that a given seed
        # produces the same initial weights and state dicts stay compatible.
        self.conv_1 = nn.Conv1d(max_len, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        # 10 output channels * remaining length (70 for the default sizes).
        self.linear_0 = nn.Linear(10 * conv_out_len, 435)
        self.linear_1 = nn.Linear(435, 292)
        self.linear_2 = nn.Linear(435, 292)

        self.linear_3 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, charset_size)

        self.relu = nn.ReLU()
        # No method of this class calls self.softmax (decode uses F.softmax);
        # the attribute is kept so existing references to it keep working.
        self.softmax = nn.Softmax()

    def encode(self, x):
        """Return ``(z_mean, z_logvar)``, each of shape ``(batch, 292)``."""
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        """Draw a latent sample: ``z_mean + exp(0.5 * z_logvar) * epsilon``.

        ``epsilon`` is standard normal noise scaled by 1e-2. Noise is drawn on
        every call, regardless of the module's train/eval mode.
        """
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
        """Decode latents ``(batch, 292)`` into ``(batch, max_len, charset_size)``."""
        z = F.selu(self.linear_3(z))
        # Feed the same latent vector to the GRU at every one of max_len steps.
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.max_len, 1)
        output, hn = self.gru(z)
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        """Return ``(reconstruction, z_mean, z_logvar)`` for input ``x``."""
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        return self.decode(z), z_mean, z_logvar

In [5]:
!rm -R 'molecular-vae'
!git clone https://github.com/aksub99/molecular-vae.git
import zipfile
zip_ref = zipfile.ZipFile('molecular-vae/data/processed.zip', 'r')
zip_ref.extractall('molecular-vae/data/')
zip_ref.close()


Cloning into 'molecular-vae'...
remote: Enumerating objects: 188, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 188 (delta 0), reused 0 (delta 0), pack-reused 185[K
Receiving objects: 100% (188/188), 2.99 MiB | 18.42 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [7]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    """Summed VAE objective: reconstruction cross-entropy plus KL term.

    Args:
        x_decoded_mean: Reconstructed probabilities, same shape as ``x``.
        x: Target tensor passed to binary cross-entropy.
        z_mean: Latent means.
        z_logvar: Latent log-variances.

    Returns:
        A scalar tensor. Both terms are summed over every element (not
        averaged), so the value scales with batch size.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False and yields the same value.
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return xent_loss + kl_loss

# Load the train/test arrays and the character set from the HDF5 file under
# molecular-vae/data/.
data_train, data_test, charset = load_dataset('molecular-vae/data/processed.h5')
# Rebind data_train from the raw array to a single-tensor TensorDataset; the
# training loop reads each batch from the loader as data[0].
data_train = torch.utils.data.TensorDataset(torch.from_numpy(data_train))
train_loader = torch.utils.data.DataLoader(data_train, batch_size=250, shuffle=True)

# Seed torch's global RNG before the model is constructed.
torch.manual_seed(42)

epochs = 30
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MolecularVAE().to(device)
# Adam with its default hyperparameters (none are passed here).
optimizer = optim.Adam(model.parameters())





train tensor(150.5092, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(121.8401, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(119.2447, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(116.9999, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(113.8947, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(112.4204, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(108.3538, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(103.4096, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(100.6264, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(96.4871, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(92.8389, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(90.3792, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(87.2286, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(84.0872, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(82.5107, device='cuda:0', grad_fn=<DivBackward0>)
train tensor(80.5589, device='cuda:0', grad_fn

In [11]:
# Wrap the raw test array only once. This cell rebinds data_test, so without
# the guard a second run would hand a TensorDataset to torch.from_numpy and fail.
if not isinstance(data_test, torch.utils.data.Dataset):
    data_test = torch.utils.data.TensorDataset(torch.from_numpy(data_test))
test_loader = torch.utils.data.DataLoader(data_test, batch_size=250, shuffle=True)


In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and report the loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``vae_loss``.

    Args:
        epoch: Epoch number, used only in the printed progress line.

    Returns:
        The epoch's summed loss divided by the number of training samples,
        as a Python float.
    """
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, total=len(train_loader)):
        # The loader yields tuples; the first element is the input batch.
        data = batch[0].to(device)
        optimizer.zero_grad()
        output, mean, logvar = model(data)

        loss = vae_loss(output, data, mean, logvar)
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the tensor itself would keep every
        # batch's autograd history alive until the end of the epoch.
        train_loss += loss.item()
    train_loss_ave = train_loss / len(train_loader.dataset)
    print(f'epoch: {epoch} \t train: {train_loss_ave:.3f}')
    return train_loss_ave

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    # Evaluate on the test loader every fifth epoch.
    if epoch % 5 == 0:
        model.eval()  # train() switches back to training mode on the next epoch
        test_loss = 0
        # Evaluation never calls backward(), so skip building autograd graphs.
        with torch.no_grad():
            for data in test_loader:
                data = data[0].to(device)
                output, mean, logvar = model(data)
                loss = vae_loss(output, data, mean, logvar)
                test_loss += loss.item()
        # Summed loss divided by the number of test samples.
        test_loss_ave = test_loss/ len(test_loader.dataset)
        print(f'---> test: {test_loss_ave:.3f}')

100%|██████████| 160/160 [00:46<00:00,  3.46it/s]


epoch: 1 	 train: 58.550


100%|██████████| 160/160 [00:48<00:00,  3.29it/s]


epoch: 2 	 train: 58.633


100%|██████████| 160/160 [00:49<00:00,  3.21it/s]


epoch: 3 	 train: 56.752


100%|██████████| 160/160 [00:51<00:00,  3.11it/s]


epoch: 4 	 train: 56.621


100%|██████████| 160/160 [00:51<00:00,  3.08it/s]


epoch: 5 	 train: 55.079
epoch: 5 	 test: 57.013


100%|██████████| 160/160 [00:52<00:00,  3.03it/s]


epoch: 6 	 train: 54.707


100%|██████████| 160/160 [00:53<00:00,  3.00it/s]


epoch: 7 	 train: 54.062


100%|██████████| 160/160 [00:53<00:00,  3.01it/s]


epoch: 8 	 train: 52.772


100%|██████████| 160/160 [00:53<00:00,  2.99it/s]


epoch: 9 	 train: 52.833


100%|██████████| 160/160 [00:53<00:00,  2.99it/s]


epoch: 10 	 train: 51.365
epoch: 10 	 test: 53.268


In [9]:
from tqdm import tqdm