In [None]:
import sys
sys.path.append("..")

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm
from scipy.stats import percentileofscore

import torch
from sklearn.gaussian_process import GaussianProcessRegressor

from moses.vae import VAE
from moses.vae_property import VAEPROPERTY
from moses.vae.trainer import VAETrainer
from moses.vae_property.trainer import VAEPROPERTYTrainer 

from moses.metrics import QED, SA, logP
from moses.utils import get_mol


from rdkit import Chem
from rdkit.Chem import PandasTools
from rdkit import rdBase
#from rdkit import RDLogger

from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

rdBase.DisableLog('rdApp.*')

import selfies as sf


from bayes_opt import BayesianOptimization
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

## GPR

In [None]:
# Read the three CSVs used in the GPR section.
# Only the first 1000 rows of the fit file are kept.
train_df = pd.read_csv("../checkpoints/opimize_gpr/gpr_fit_ZINC250K_df.csv").head(1000)
test_df = pd.read_csv("../checkpoints/opimize_gpr/gpr_test_ZINC250K_df.csv")
start_df = pd.read_csv("../checkpoints/opimize_gpr/opt_start_ZINC250K_df.csv")

In [None]:
# Placeholder scores recorded for generated strings that RDKit cannot parse
# (used in the `gen_mol is None` branches further down).
# With the objective 5 * qed - sa these give an invalid molecule a value of -600.
nan_qed = -100
nan_sa = 100

In [None]:
# Report (rows, columns) of each frame loaded for the GPR section.
for split_name, split_frame in (('train', train_df), ('test', test_df), ('start', start_df)):
    print(f'gpr {split_name}: {split_frame.shape}')

### Choose model

In [None]:
# String representation used to pick the checkpoint folder ('smiles' here;
# the folder name is suffixed with this value).
data_type = 'smiles'

# Label for this run; not referenced elsewhere in the visible notebook.
model_name = 'VAEProp_obj_w0.1'
folder_path = f"../checkpoints/ZINC250K_vae_property_obj_proploss_w0.1_{data_type}"
# NOTE(review): both files are loaded as whole Python objects via torch.load
# (pickle under the hood) -- only load checkpoints from a trusted source.
config = torch.load(f'{folder_path}/vae_property_config.pt')
vocab = torch.load(f'{folder_path}/vae_property_vocab.pt')

# Sanity check: the representation and property tasks stored in the config.
print(f"Use Selfies: {config.use_selfies}")
print(config.reg_prop_tasks)

In [None]:
# Columns handed to the trainer: the molecule-string column selected by the
# config, followed by the property columns.
cols = ['SELFIES' if config.use_selfies else 'SMILES', 'logP', 'qed', 'SAS', 'obj']
train_data = train_df[cols].values
test_data = test_df[cols].values
start_data = start_df[cols].values


model_path = f'{folder_path}/vae_property_model.pt'

model = VAEPROPERTY(vocab, config)
# map_location="cpu": the model is built on the CPU and never moved in this
# notebook, so remap the weights in case the checkpoint was saved from a GPU
# (otherwise torch.load fails on a machine without CUDA).
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# shuffle=False keeps batches in row order, so encoded outputs line up with the frames.
trainer = VAEPROPERTYTrainer(config)
train_loader = trainer.get_dataloader(model, train_data, shuffle=False)
test_loader = trainer.get_dataloader(model, test_data, shuffle=False)
start_loader = trainer.get_dataloader(model, start_data, shuffle=False)

In [None]:
model.eval()

# Per-sample accumulators filled batch by batch.
x_list = []
z_list = []
mu_list = []
logvar_list = []
y_list = []

# Inference only: no_grad() stops autograd from recording a graph for every batch.
with torch.no_grad():
    for batch in train_loader:
        x = batch[0]
        y = batch[1]
        x_list.extend(x)
        y_list.extend(np.array(y).squeeze())

        mu, logvar, z, _ = model.forward_encoder(x)
        z_list.extend(z.detach().cpu().numpy())
        mu_list.extend(mu.detach().cpu().numpy())
        logvar_list.extend(logvar.detach().cpu().numpy())

# Flatten the collected targets to 2-D so each row is one sample.
y_list = np.array(y_list).squeeze()
GP_train_y = y_list.reshape(-1, y_list.shape[-1])

# One row per sample: property columns, then the latent statistics (one array per cell),
# with the decoded input string inserted as the first column.
train_data_df = pd.DataFrame(GP_train_y, columns=['logP', 'qed', 'SAS', 'obj'])
train_data_df = pd.concat([train_data_df , pd.DataFrame({'z': z_list, 'mu': mu_list, 'logvar': logvar_list})], axis=1)
train_data_df.insert(0, 'SELFIES' if config.use_selfies else 'SMILES', [vocab.ids2string(point.cpu().detach().numpy()) for point in x_list])

In [None]:
model.eval()

# Per-sample accumulators for the test split.
test_x_list = []
test_z_list = []
test_mu_list = []
test_logvar_list = []
test_y_list = []

# Inference only: no_grad() stops autograd from recording a graph for every batch.
with torch.no_grad():
    for batch in test_loader:
        x = batch[0]
        y = batch[1]
        test_x_list.extend(x)
        test_y_list.extend(np.array(y).squeeze())

        mu, logvar, z, _ = model.forward_encoder(x)
        test_z_list.extend(z.detach().cpu().numpy())
        test_mu_list.extend(mu.detach().cpu().numpy())
        test_logvar_list.extend(logvar.detach().cpu().numpy())

# Flatten the collected targets to 2-D so each row is one sample.
test_y_list = np.array(test_y_list).squeeze()
GP_test_y = test_y_list.reshape(-1, test_y_list.shape[-1])

# One row per sample: property columns followed by the latent statistics.
test_data_df = pd.DataFrame(GP_test_y, columns=['logP', 'qed', 'SAS', 'obj'])
test_data_df = pd.concat([test_data_df , pd.DataFrame({'z': test_z_list, 'mu': test_mu_list, 'logvar': test_logvar_list})], axis=1)
test_data_df

In [None]:
# Stack the per-row latent vectors stored in the 'z' column into one tensor per split.
GP_Train_x = torch.tensor(np.array(train_data_df['z'].tolist()))
GP_Test_x = torch.tensor(np.array(test_data_df['z'].tolist()))

# Matching 'obj' targets as plain NumPy arrays.
GP_Train_y = np.array(train_data_df['obj'].tolist())
GP_Test_y = np.array(test_data_df['obj'].tolist())

In [None]:
# Decode the GP training latents with the VAE and score what comes out.
gen = model.sample(len(GP_Train_x), max_len=100, z=GP_Train_x, temp=1.0, test=True)
gen_df = pd.DataFrame(gen, columns=['gen_SELFIES' if config.use_selfies else 'gen_SMILES'])

# SELFIES output is translated to SMILES first; either way RDKit parses the SMILES column.
if config.use_selfies:
    gen_df['gen_SMILES'] = [sf.decoder(x) for x in gen_df['gen_SELFIES']]
mol = gen_df['gen_SMILES'].apply(Chem.MolFromSmiles)

qed_list, sa_list = [], []
null_cnt = 0

for gen_mol in mol:
    if gen_mol is not None:
        qed_list.append(QED(gen_mol))
        sa_list.append(SA(gen_mol))
        continue
    # Unparseable string: record the placeholder scores and count it.
    qed_list.append(nan_qed)
    sa_list.append(nan_sa)
    null_cnt += 1

gen_df['gen_qed'] = qed_list
gen_df['gen_sa'] = sa_list

In [None]:
# How many decoded strings failed to parse, then the distinct decoded SMILES.
print(f"Null SMILES: {null_cnt}")
gen_df['gen_SMILES'].unique()

In [None]:
import torch
import gpytorch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from botorch.acquisition import ExpectedImprovement
from botorch.optim import optimize_acqf


def calculate_qed_sa(z):
    """Decode latent vectors with the global VAE and score the results.

    Args:
        z: batch of latent vectors; ``len(z)`` strings are requested from
            ``model.sample``.

    Returns:
        Tuple ``(qed, sa)`` of NumPy arrays with one entry per decoded string.
        Strings RDKit cannot parse get ``nan_qed`` / ``nan_sa`` instead.
    """
    decoded = model.sample(len(z), max_len=100, z=z, temp=1.0, test=True)
    gen_df = pd.DataFrame(decoded, columns=['gen_SELFIES' if config.use_selfies else 'gen_SMILES'])

    # SELFIES output is translated to SMILES before parsing.
    if config.use_selfies:
        gen_df['gen_SMILES'] = [sf.decoder(s) for s in gen_df['gen_SELFIES']]
    mols = gen_df['gen_SMILES'].apply(Chem.MolFromSmiles)

    # One (qed, sa) pair per molecule; placeholders where parsing failed.
    scored = [(nan_qed, nan_sa) if m is None else (QED(m), SA(m)) for m in mols]
    gen_df['gen_qed'] = [pair[0] for pair in scored]
    gen_df['gen_sa'] = [pair[1] for pair in scored]

    return gen_df['gen_qed'].values, gen_df['gen_sa'].values

# Objective function definition
def objective_function(z):
    """Return ``5 * qed - sa`` for the molecules decoded from latent batch ``z``."""
    qed_scores, sa_scores = calculate_qed_sa(z)
    weighted_qed = 5 * qed_scores
    return weighted_qed - sa_scores

# Collect the initial data
def initial_data():
    """Build the starting GP data set from notebook globals.

    Returns:
        ``(train_z, train_y)`` where ``train_z`` is ``GP_Train_x`` and
        ``train_y`` is ``5 * gen_qed - gen_sa`` from ``gen_df`` with a trailing
        singleton dimension appended.
    """
    train_z = GP_Train_x
    objective = 5*gen_df['gen_qed'] - gen_df['gen_sa']
    train_y = torch.tensor(objective)
    return train_z, train_y.unsqueeze(-1)

# Train the GPR model
def train_gp(train_z, train_y):
    """Fit a ``SingleTaskGP`` on ``(train_z, train_y)`` and return the fitted model.

    The model is fitted through its exact marginal log-likelihood.
    """
    surrogate = SingleTaskGP(train_z, train_y)
    fit_gpytorch_model(
        gpytorch.mlls.ExactMarginalLogLikelihood(surrogate.likelihood, surrogate)
    )
    return surrogate

# Optimize the acquisition function
def optimize_acq(gp, bounds, best_f=None, num_restarts=5, raw_samples=20):
    """Propose the next latent point by maximising Expected Improvement.

    Args:
        gp: fitted GP model used as the surrogate.
        bounds: bounds tensor passed straight to ``optimize_acqf``.
        best_f: incumbent objective value for EI. When ``None`` (the default)
            it falls back to the maximum of the notebook-level ``train_y``,
            which is what the original implementation always did.
        num_restarts: forwarded to ``optimize_acqf``.
        raw_samples: forwarded to ``optimize_acqf``.

    Returns:
        The candidate returned by ``optimize_acqf`` for ``q=1``.
    """
    if best_f is None:
        # Backward-compatible fallback: read the incumbent from the global training targets.
        best_f = train_y.max().item()

    acqf = ExpectedImprovement(gp, best_f=best_f)
    new_z, _ = optimize_acqf(
        acqf,
        bounds=bounds,
        q=1,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
    )

    return new_z

In [None]:
# Initial data
train_z, train_y = initial_data()

# Initial settings
# Search box for the acquisition optimiser: every latent dimension shares the
# global min / max of the initial latents. The width comes from train_z itself
# rather than from a `z` variable left over from an earlier cell.
# TODO: adapt the bounds to the shape of z
latent_dim = train_z.shape[1]
bounds = torch.stack([
    torch.full((latent_dim,), train_z.min().item(), dtype=train_z.dtype),
    torch.full((latent_dim,), train_z.max().item(), dtype=train_z.dtype),
])

print(train_z.shape)
print(train_y.shape)

num_iterations = 50
new_z_list = []  # candidates that beat every earlier candidate of this run
all_z_list = []  # every candidate proposed, in order

best_perform = -np.inf

for iteration in range(num_iterations):
    # Refit the surrogate on everything observed so far, then ask for one new point.
    gp = train_gp(train_z, train_y)
    new_z = optimize_acq(gp, bounds)

    new_y = torch.tensor([objective_function(new_z)])

    if new_y > best_perform:
        best_perform = new_y
        new_z_list.append(new_z)

    all_z_list.append(new_z)

    # Update the data
    train_z = torch.cat((train_z, new_z), dim=0)
    train_y = torch.cat((train_y, new_y), dim=0)

print(f"최적의 z index {train_y.argmax()}:", train_z[train_y.argmax()])
print("최적의 목적 함수 값:", train_y.max().item())

In [None]:
# Stack the candidates into 2-D arrays with one row per candidate.
# A plain .squeeze() would collapse a single candidate to 1-D, after which
# len(new_z_list) would be the latent size instead of the number of candidates.
new_z_list = np.stack([np.asarray(cand) for cand in new_z_list]).reshape(len(new_z_list), -1)
all_z_list = np.stack([np.asarray(cand) for cand in all_z_list]).reshape(len(all_z_list), -1)

# Show both shapes (only the last expression of a cell is displayed).
new_z_list.shape, all_z_list.shape

In [None]:
# Decode and score the candidates that improved on the running best.
gen = model.sample(len(new_z_list), max_len=100, z=torch.tensor(new_z_list), temp=1.0, test=True)
gen_df = pd.DataFrame(gen, columns=['gen_SELFIES' if config.use_selfies else 'gen_SMILES'])

# SELFIES output is translated to SMILES first; either way RDKit parses the SMILES column.
if config.use_selfies:
    gen_df['gen_SMILES'] = [sf.decoder(x) for x in gen_df['gen_SELFIES']]
mol = gen_df['gen_SMILES'].apply(Chem.MolFromSmiles)

qed_list, sa_list = [], []

for gen_mol in mol:
    if gen_mol is not None:
        qed_list.append(QED(gen_mol))
        sa_list.append(SA(gen_mol))
        continue
    # NOTE(review): this cell uses 0 / 7 as the fallback for unparseable strings,
    # not nan_qed / nan_sa as in the scoring code above -- confirm this is intended.
    qed_list.append(0)
    sa_list.append(7)

gen_df['gen_qed'] = qed_list
gen_df['gen_sa'] = sa_list

In [None]:
# Objective per decoded molecule: five times QED minus the SA score.
gen_df['obj'] = gen_df['gen_qed'].mul(5) - gen_df['gen_sa']

In [None]:
# Display the scored candidates that improved on the running best.
gen_df

In [None]:
# Objective value of each improving candidate, in the order they were found.
obj_values = gen_df['obj'].values
plt.plot(np.arange(len(obj_values)), obj_values)
plt.show()

In [None]:
# Parse the SMILES into RDKit molecules and draw them in a grid, four per row.
gen_df['RoMol'] = gen_df['gen_SMILES'].apply(Chem.MolFromSmiles)
# The legend shows whichever string representation the model generated.
legend_col = 'gen_SELFIES' if data_type == 'selfies' else 'gen_SMILES'
display(PandasTools.FrameToGridImage(gen_df, column='RoMol', legendsCol=legend_col, molsPerRow=4))

In [None]:
# Decode and score every candidate proposed during the optimisation.
gen = model.sample(len(all_z_list), max_len=100, z=torch.tensor(all_z_list), temp=1.0, test=True)
gen_df = pd.DataFrame(gen, columns=['gen_SELFIES' if config.use_selfies else 'gen_SMILES'])

# SELFIES output is translated to SMILES first; either way RDKit parses the SMILES column.
if config.use_selfies:
    gen_df['gen_SMILES'] = [sf.decoder(x) for x in gen_df['gen_SELFIES']]
mol = gen_df['gen_SMILES'].apply(Chem.MolFromSmiles)

qed_list = []
sa_list = []

for smiles, gen_mol in zip(gen_df['gen_SMILES'], mol):
    if gen_mol is None:
        print(f'Error: {smiles}')
        # Keep one entry per row: skipping invalid molecules made the score
        # lists shorter than gen_df and the column assignment below raised.
        # NaN marks the score as missing instead of inventing a value.
        qed_list.append(np.nan)
        sa_list.append(np.nan)
    else:
        qed_list.append(QED(gen_mol))
        sa_list.append(SA(gen_mol))

gen_df['gen_qed'] = qed_list
gen_df['gen_sa'] = sa_list

In [None]:
# Objective per decoded molecule: five times QED minus the SA score.
gen_df['obj'] = gen_df['gen_qed'].mul(5) - gen_df['gen_sa']

In [None]:
# Display the scored frame for all proposed candidates.
gen_df

In [None]:
# Objective value of every proposed candidate, in proposal order.
obj_values = gen_df['obj'].values
plt.plot(np.arange(len(obj_values)), obj_values)
plt.show()

In [None]:
# Parse the SMILES into RDKit molecules and draw them in a grid, five per row.
gen_df['RoMol'] = gen_df['gen_SMILES'].apply(Chem.MolFromSmiles)
# 'gen_SELFIES' only exists when the model generates SELFIES; fall back to the
# SMILES column so a SMILES model does not hit a KeyError here.
legend_col = 'gen_SELFIES' if 'gen_SELFIES' in gen_df.columns else 'gen_SMILES'
display(PandasTools.FrameToGridImage(gen_df, column='RoMol', legendsCol=legend_col, molsPerRow=5))

## Latent Vector Interpolation

In [None]:
import torch
import numpy as np
import pandas as pd
from viz_utils import slerp, InterpolationLoader, z_to_smiles

In [None]:
# Settings for the interpolation section; all of them are passed to
# InterpolationLoader / z_to_smiles in the next cell.
model_type = 'vae_property' # options: 'vae_property', 'vae'
data_type = 'selfies'  # string representation; switch with the line below
# data_type = 'smiles'
steps = 4  # number of evenly spaced interpolation fractions between 0 and 1
epoch = 60  # passed as best_epoch
sample_1 = 3  # passed as i_1
sample_2 = 5  # passed as i_2

In [None]:
z_list, y_list, _, train_data, model = InterpolationLoader(
    dataPATH="../moses/dataset/data/ZINC250K/",
    model_type=model_type,
    data_type=data_type,
    best_epoch=epoch,
    i_1=sample_1,
    i_2=sample_2,
)

# First column of the returned data holds the molecule strings.
original_mol = train_data[:, 0]

# Interpolate between the two latent rows with slerp (from viz_utils) at
# `steps` evenly spaced fractions in [0, 1], then decode each point.
z_start, z_end = z_list[0, :], z_list[1, :]
fractions = np.linspace(0, 1, steps)
interpolated_latents = torch.tensor(np.array([slerp(frac, z_start, z_end) for frac in fractions]))
viz_df = z_to_smiles(
    model,
    original_mol,
    interpolated_latents,
    data_type=data_type,
    steps=steps,
    temp=0.3,
    argmax=False,
)

In [None]:
# Pick the column matching the chosen representation and show the decoded strings.
result_col = 'SMILES' if data_type == 'smiles' else "SELFIES"
result_mol = viz_df[result_col].values

result_mol

## Latent Space Visualization

In [None]:
# Representation used for the latent-space plots; it selects the checkpoint folder below.
data_type = 'smiles' # alternative: 'selfies'
# data_type = 'selfies' # uncomment to use the SELFIES checkpoint

In [None]:
# Load the ZINC250K train / test CSVs; this rebinds train_df and test_df
# from the GPR section above.
train_df = pd.read_csv("../moses/dataset/data/ZINC250K/train.csv")
test_df = pd.read_csv("../moses/dataset/data/ZINC250K/test.csv")

In [None]:
# Checkpoint folder for the chosen representation.
if data_type == 'selfies':
    folder_path = "../checkpoints/ZINC250K_vae_property_obj_proploss_w0.1_selfies"
else:
    folder_path = "../checkpoints/ZINC250K_vae_property_obj_proploss_w0.1_smiles"


config = torch.load(f'{folder_path}/vae_property_config.pt')
vocab = torch.load(f'{folder_path}/vae_property_vocab.pt')

print(f"Use Selfies: {config.use_selfies}")
print(config.reg_prop_tasks)

# Molecule-string column selected by the config, followed by the property columns.
cols = ['SELFIES' if config.use_selfies else 'SMILES', 'logP', 'qed', 'SAS', 'obj']
train_data = train_df[cols].values
test_data = test_df[cols].values

model_path = f'{folder_path}/vae_property_model_080.pt'

model = VAEPROPERTY(vocab, config)
# map_location="cpu": the model is built on the CPU, so remap the weights in
# case the checkpoint was saved from a GPU (otherwise torch.load fails on a
# machine without CUDA).
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# shuffle=False keeps batches in row order.
trainer = VAEPROPERTYTrainer(config)
train_loader = trainer.get_dataloader(model, train_data, shuffle=False)
test_loader = trainer.get_dataloader(model, test_data, shuffle=False)

In [None]:
# One panel per checkpoint: PCA of the encoder means, coloured by the last target column.
fig, axes = plt.subplots(1, 6, figsize=(30, 4))

for i, epoch in enumerate(['00', 20, 40, 60, 80, 'final']):

    # Intermediate checkpoints are named ..._0<epoch>.pt; the final one has no suffix.
    if epoch == 'final':
        model_path = f'{folder_path}/vae_property_model.pt'
    else:
        model_path = f'{folder_path}/vae_property_model_0{epoch}.pt'

    model = VAEPROPERTY(vocab, config)
    # map_location="cpu": remap weights in case the checkpoint was saved from a GPU.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))

    model.eval()

    x_list = []
    z_list = []
    mu_list = []
    logvar_list = []
    y_list = []

    # Inference only: no_grad() stops autograd from recording a graph for every batch.
    with torch.no_grad():
        for batch in train_loader:
            x = batch[0]
            y = batch[1]
            x_list.extend(x)
            y_list.extend(np.array(y).squeeze())

            mu, logvar, z, _ = model.forward_encoder(x)
            z_list.extend(z.detach().cpu().numpy())
            mu_list.extend(mu.detach().cpu().numpy())
            logvar_list.extend(logvar.detach().cpu().numpy())

    # Project the encoder means onto their first two principal components.
    viz = PCA(n_components=2)
    z_viz = viz.fit_transform(mu_list)
    explained_variance = viz.explained_variance_ratio_
    print(f"(Epoch {epoch})Explained variance: {explained_variance}")

    # Colour by the last target column.
    y_list = np.array(y_list)[:, -1]

    # Rescale both components to [0, 1] so the panels share the same axis range.
    z_viz = MinMaxScaler().fit_transform(z_viz)

    scatter = axes[i].scatter(z_viz[:, 0], z_viz[:, 1], c=y_list, cmap='viridis', marker='.', s=10, alpha=0.5, edgecolors='none')

    axes[i].set_title(f'Epoch {epoch}')
    axes[i].set_xlabel('PC1')
    axes[i].set_ylabel('PC2')

    fig.colorbar(scatter, ax=axes[i])

plt.tight_layout()
plt.show()