<a href="https://colab.research.google.com/github/yingzibu/drug_design_JAK/blob/main/conditionalVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Conditional VAE

Reference implementations:

- https://github.com/unnir/cVAE/blob/master/cvae.py
- https://github.com/chendaichao/VAE-pytorch

In [1]:
# Mount Google Drive at /content/drive (Colab-only import; this cell needs the
# google.colab package to be available).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
cd /content/drive/MyDrive/A_JAK_design

/content/drive/MyDrive/A_JAK_design


In [10]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Select the compute device once; the rest of the notebook moves tensors and
# the model to whatever `device` names.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print({'cuda': 'use GPU', 'cpu': 'use CPU'}[device])
import warnings


def warn(*args, **kwargs):
    """Accept any arguments and do nothing; used as a stand-in for ``warnings.warn``."""
    return None


# Replace the standard-library entry point so that calls to warnings.warn
# made after this point are silently discarded.
warnings.warn = warn

# from help_function.jak_dataset import *
# from help_function.function import *
import os
def create_path(path):
    """Make sure the directory ``path`` exists, reporting what was done.

    Prints whether ``path`` already exists and, if it does not, creates it
    (including any missing parent directories) and prints a confirmation.

    Args:
        path: directory path to check and, if needed, create.

    Returns:
        None.
    """
    already_exists = os.path.exists(path)
    print(path, ' folder is in directory: ', already_exists)
    if not already_exists:
        # exist_ok=True keeps this safe if the directory appears between the
        # existence check above and the creation call.
        os.makedirs(path, exist_ok=True)
        print(path, " is created!")

# Ensure the data folder and its figures sub-folder exist (parent first).
for folder in ('../Data', '../Data/Figures'):
    create_path(folder)


use GPU
../Data  folder is in directory:  True
../Data/Figures  folder is in directory:  False
../Data/Figures  is created!


In [11]:
def one_hot(labels, class_size):
    """Convert integer class labels into one-hot float rows.

    Args:
        labels: tensor of integer class indices, one per sample.
        class_size: number of classes, i.e. the width of each output row.

    Returns:
        Float tensor with one row per label and ``class_size`` columns,
        containing a single 1 per row, placed on the module-level ``device``.
    """
    # One column of indices so scatter_ can set all rows in a single call
    # instead of a Python loop with one indexing operation per sample.
    index = labels.reshape(-1, 1).long()
    targets = torch.zeros(index.size(0), class_size, device=index.device)
    targets.scatter_(1, index, 1.0)
    return targets.to(device)

class CVAE(nn.Module):
    """Conditional variational autoencoder with one hidden layer per side.

    The encoder maps the concatenation of a flattened input and its
    condition vector to the mean and log-variance of the latent code; the
    decoder maps the concatenation of a latent code and the condition vector
    back to a vector of ``feature_size`` values in (0, 1).

    Args:
        feature_size: number of values in one flattened input sample.
        latent_size: dimensionality of the latent code.
        class_size: width of the condition vector (e.g. one-hot label).
        hidden_size: width of the hidden layer in encoder and decoder.
    """

    def __init__(self, feature_size, latent_size, class_size, hidden_size=400):
        super(CVAE, self).__init__()
        self.feature_size = feature_size
        self.latent_size = latent_size
        self.class_size = class_size

        # encode
        self.fc1 = nn.Linear(feature_size + class_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)  # latent mean
        self.fc22 = nn.Linear(hidden_size, latent_size)  # latent log-variance

        # decode
        self.fc3 = nn.Linear(latent_size + class_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, feature_size)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, c):  # Q(z|x, c)
        """Return ``(mu, logvar)`` of the latent distribution for ``x`` given ``c``.

        bs: batch_size
        x:  (bs, feature_size)
        c:  (bs, class_size)
        """
        # (bs, feature_size + class_size), e.g. (64, 784 + 10) for MNIST
        inputs = torch.cat([x, c], 1)
        h1 = self.elu(self.fc1(inputs))
        z_mu = self.fc21(h1)
        # The second head is consumed as a log-variance by reparameterize().
        z_logvar = self.fc22(h1)
        return z_mu, z_logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` standard normal and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z, c)
        """Map latent codes ``z`` and conditions ``c`` to outputs in (0, 1).

        z: (bs, latent_size)
        c: (bs, class_size)
        """
        inputs = torch.cat([z, c], 1)  # (bs, latent_size + class_size)
        h3 = self.elu(self.fc3(inputs))
        h4 = self.sigmoid(self.fc4(h3))
        return h4

    def forward(self, x, c):
        """Encode, sample and decode; return ``(recon_x, mu, logvar)``.

        ``x`` may have any shape whose elements flatten to rows of
        ``feature_size`` values (e.g. (64, 1, 28, 28) for MNIST).
        """
        # Flatten with the configured feature size rather than a hard-coded
        # 28*28 so the model also works for non-MNIST inputs.
        mu, logvar = self.encode(x.view(-1, self.feature_size), c)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, c)
        return recon_x, mu, logvar

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: reconstructions with values in (0, 1), one row per sample.
        x: targets; reshaped to one row per row of ``recon_x``.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor: BCE(recon_x, x) summed over every element, plus
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    # Flatten the target to match the reconstruction instead of assuming
    # 28*28 features, so the loss works for any feature size.
    target = x.view(recon_x.size(0), -1)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [12]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. Prints the per-sample loss of every 200th batch and, at the
    end, the summed loss divided by the dataset size.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    running_loss = 0
    for step, (images, digits) in enumerate(train_loader):
        images, digits = images.to(device), digits.to(device)
        condition = one_hot(digits, 10)  # (batch, 10) one-hot condition

        recon_batch, mu, logvar = model(images, condition)
        optimizer.zero_grad()
        batch_loss = loss_function(recon_batch, images, mu, logvar)
        batch_loss.backward()
        running_loss += batch_loss.detach().cpu().numpy()
        optimizer.step()

        # Only every 200th batch is reported.
        if step % 200 != 0:
            continue
        seen = step * len(images)
        percent = 100. * step / len(train_loader)
        per_sample = batch_loss.item() / len(images)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, per_sample))

    average = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average))

def test(epoch):
    """Evaluate on ``test_loader`` and save a reconstruction preview.

    Uses the notebook-level ``model``, ``test_loader`` and ``device``. For
    the first batch, up to five inputs and their reconstructions are written
    to ``reconstruction_<epoch>.png``. Prints the summed loss divided by the
    dataset size.

    Args:
        epoch: epoch number, used in the image file name.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch_no, (images, digits) in enumerate(test_loader):
            images, digits = images.to(device), digits.to(device)
            condition = one_hot(digits, 10)
            recon_batch, mu, logvar = model(images, condition)
            batch_loss = loss_function(recon_batch, images, mu, logvar)
            total_loss += batch_loss.detach().cpu().numpy()

            # The preview image is produced from the first batch only.
            if batch_no != 0:
                continue
            n = min(images.size(0), 5)
            originals = images[:n]
            reconstructions = recon_batch.view(-1, 1, 28, 28)[:n]
            comparison = torch.cat([originals, reconstructions])
            file_name = 'reconstruction_' + str(epoch) + '.png'
            save_image(comparison.cpu(), file_name, nrow=n)

    total_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(total_loss))



In [14]:
# Data-loading options and training hyper-parameters.
kwargs = {'num_workers':1, 'pin_memory': True}
batch_size = 64
latent_size = 20
class_size = 10  # MNIST digits 0-9
epochs = 10
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../Data', train=True, download=True,
                        transform=transforms.ToTensor()),
        batch_size = batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../Data', train=False, download=True,
                        transform=transforms.ToTensor()),
        batch_size = batch_size, shuffle=True, **kwargs)

# feature_size = 28 * 28, one flattened MNIST image
model = CVAE(28*28, latent_size, class_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # One generated sample per class: row i of the identity matrix is
        # the one-hot condition for digit i. Created on `device` (not via
        # .cuda()) so the loop also runs on CPU-only machines.
        c = torch.eye(class_size, device=device)
        # Latent width follows latent_size instead of a duplicated literal.
        sample = torch.randn(class_size, latent_size).to(device)
        sample = model.decode(sample, c).cpu()
        save_image(sample.view(class_size, 1, 28, 28),
                   'sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 143.2693
====> Test set loss: 120.1011
====> Epoch: 2 Average loss: 116.9275
====> Test set loss: 112.6314
====> Epoch: 3 Average loss: 111.6837
====> Test set loss: 108.9167
====> Epoch: 4 Average loss: 108.5511
====> Test set loss: 106.3975
====> Epoch: 5 Average loss: 106.5775
====> Test set loss: 105.1081
====> Epoch: 6 Average loss: 105.2299
====> Test set loss: 104.0493
====> Epoch: 7 Average loss: 104.2675
====> Test set loss: 103.3581
====> Epoch: 8 Average loss: 103.5014
====> Test set loss: 102.7383
====> Epoch: 9 Average loss: 102.9101
====> Test set loss: 102.1490
====> Epoch: 10 Average loss: 102.4457
====> Test set loss: 101.6548


In [3]:
# !pip install pubchempy --quiet
# !pip install transformers --quiet
# !pip install cairosvg --quiet
# !pip install varname --quiet
# !pip install Cython --quiet
# !pip install rdkit --quiet
# # !pip install molsets --quiet
# !pip install pathlib --quiet
# !pip install xgboost==1.6.1 --quiet
# !pip install dgllife --quiet
# !pip install molvs --quiet

# # !pip install dgl==1.1 --quiet # cpu version, usable for calculation

# !pip uninstall dgl -y # dgl cuda version for training using gpu
# !pip install  dgl -f https://data.dgl.ai/wheels/cu118/repo.html --quiet
# !pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html --quiet

# !python --version
# import torch
# print('torch version: ', torch.__version__)
# print('cuda available: ', torch.cuda.is_available())
# import dgl
# print('dgl version: ', dgl.__version__)
# import dgllife
# print('dgllife version: ', dgllife.__version__)
# import rdkit
# print('rdkit version: ', rdkit.__version__)
# import molvs
# print('molvs version: ', molvs.__version__)
# import matplotlib
# print('matplotlib version: ', matplotlib.__version__)

In [15]:
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt
# import torch
# import torch.utils.data
# from torch.utils.data import DataLoader
# from dgllife.model import model_zoo
# from dgllife.utils import smiles_to_bigraph
# from dgllife.utils import AttentiveFPAtomFeaturizer
# from dgllife.utils import AttentiveFPBondFeaturizer
# from dgllife.data import MoleculeCSVDataset
# import dgl
# import matplotlib
# import matplotlib.cm as cm
# from IPython.display import SVG, display
# from rdkit import Chem
# from rdkit.Chem import rdDepictor
# from rdkit.Chem.Draw import rdMolDraw2D