In [None]:
import numpy as np
import pandas as pd
import torch


from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')
from rdkit.Chem import PandasTools


from collections import Counter
from itertools import product
import zipfile
from io import BytesIO

import selfies as sf

from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from tqdm import tqdm

In [None]:
import sys
sys.path.append("..")
import moses
from moses.vae import VAE
from moses.vae_property import VAEPROPERTY
from moses.utils import CharVocab, StringDataset, SELFIESVocab
from moses.vae.trainer import VAETrainer
from moses.vae_property.trainer import VAEPROPERTYTrainer 

from moses.metrics import QED, SA, logP
from moses.utils import get_mol

# 1. Loading data

## 1.1 Loading model

In [None]:
# Directory holding the trained property-VAE artifacts, and the file names inside it.
folder_path = "../model_results/ZINC250K_vae_property_obj_ws"

config_path = "vae_property_config.pt"
model_path = "vae_property_model_080.pt"
vocab_path = "vae_property_vocab.pt"

# map_location="cpu" lets artifacts that were saved from a GPU run be loaded on a
# machine without CUDA; the freshly constructed model lives on the CPU anyway.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
config = torch.load(f"{folder_path}/{config_path}", map_location="cpu")
vocab = torch.load(f"{folder_path}/{vocab_path}", map_location="cpu")

model = VAEPROPERTY(vocab, config)
model.load_state_dict(torch.load(f"{folder_path}/{model_path}", map_location="cpu"))

# The notebook only encodes/decodes; switch to inference mode so train-only layers
# (e.g. dropout, if the architecture has any) do not perturb the latent codes.
model.eval();



## 1.2 Loading train data

In [None]:
# Folder containing the ZINC250K split files (relative to this notebook).
data_folder_path = "../moses/dataset/data/ZINC250K"

# Training split file name inside data_folder_path.
file_name = "train.csv"

# Read the training table into a DataFrame; later cells read its "SMILES" column.
data = pd.read_csv(f"{data_folder_path}/{file_name}")

In [None]:
# Preview the first rows. Per the original author's note the table holds
# SMILES, SELFIES, logP, QED, SAS and obj columns.
data.head()

### 1.2.1 generate whole_latent_data

In [None]:
# Seed NumPy's global RNG so the same subset is drawn on every run.
np.random.seed(42)
# Draw 1000 distinct row positions (replace=False) and keep those rows;
# the original row labels are preserved in samples.index.
samples = data.iloc[np.random.choice(data.shape[0], 1000, replace=False)]

In [None]:
def get_latent_info(mol_smiles, model):
    """Encode a single molecule string and return its latent statistics.

    Parameters
    ----------
    mol_smiles : str
        Molecule string handed to ``model.string2tensor``.
    model
        Object exposing ``string2tensor`` and ``forward_encoder``.

    Returns
    -------
    tuple
        ``(mu, log_var, z)`` - the first three values returned by
        ``model.forward_encoder``; its fourth return value is discarded.
        The tensors carry no autograd history.
    """
    # The encoder is fed a batch with exactly one row.
    tokens = model.string2tensor(mol_smiles).reshape(1, -1)

    # Pure inference: skip building an autograd graph for every molecule.
    with torch.no_grad():
        mu, log_var, z, _ = model.forward_encoder(tokens)

    return mu, log_var, z

In [None]:
def get_latent_whole_info(data, model):
    """Encode every row of ``data`` and collect the latent statistics.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``"SMILES"`` column; its index is reused for the outputs.
    model
        Encoder model forwarded to ``get_latent_info``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(mu_df, logvar_df, z_df)`` - single-column frames named ``"mu"``,
        ``"logvar"`` and ``"z"``. Each cell holds the NumPy array obtained from
        the corresponding tensor (presumably shaped ``(1, latent_dim)`` since one
        molecule is encoded at a time - confirm against the model).
    """
    # One (mu, log_var, z) triple per molecule, in row order.
    encoded = [get_latent_info(smiles, model) for smiles in data["SMILES"]]

    def _column_frame(position, column):
        # Wrap each array in a one-element row so it lands in a single object cell.
        cells = [[triple[position].detach().cpu().numpy()] for triple in encoded]
        return pd.DataFrame(cells, columns=[column], index=data.index)

    return _column_frame(0, "mu"), _column_frame(1, "logvar"), _column_frame(2, "z")

In [None]:
# Encode the sampled molecules; each frame has one column and shares samples.index.
mu_df, log_var_df, z_df = get_latent_whole_info(samples, model)

In [None]:
# Column-wise join (axis=1) of the sampled rows with their mu / logvar / z columns.
whole_latent_info = pd.concat([samples, mu_df, log_var_df, z_df], axis=1)

In [None]:
# Display the combined table (molecule columns plus latent columns).
whole_latent_info

### 1.2.2 For the test molecule, generate a molecular diagram

In [None]:
def sample_latent_space(mu, model, ranges=(3,3), latent_dim=128, n_trials=1000, n_grid=6, temp=0.01):
    """Scan a 2-D slice of the latent space around ``mu`` and decode each grid point.

    Two latent dimensions are picked with a fixed seed (42). ``mu`` is shifted
    along those two axes by ``n_grid`` steps in each direction, and every shifted
    point is decoded with ``decode_z``.

    Parameters
    ----------
    mu : array-like
        Centre of the scan; flattened to a vector of ``latent_dim`` values.
    model
        Model forwarded to ``decode_z``.
    ranges : tuple
        ``(x_range, y_range)`` - maximum absolute offset along each chosen axis.
    latent_dim : int
        Size of the latent vector; must equal the number of values in ``mu``.
    n_trials : int
        Number of decodes per grid point (forwarded to ``decode_z``).
    n_grid : int
        Number of steps on each side of the centre; the grid is
        ``(2*n_grid+1) x (2*n_grid+1)``.
    temp : float
        Temperature forwarded to ``decode_z``.

    Returns
    -------
    pandas.DataFrame
        Cell ``[i_x, i_y]`` holds whatever ``decode_z`` returned for the point
        offset by ``dx[i_x]`` on the first axis and ``dy[i_y]`` on the second.

    Raises
    ------
    ValueError
        If ``mu`` does not contain exactly ``latent_dim`` values.
    """
    # float64 matches the precision of the offsets produced by np.linspace.
    center = np.asarray(mu, dtype=np.float64).reshape(-1)
    if center.shape[0] != latent_dim:
        raise ValueError(
            f"mu has {center.shape[0]} values but latent_dim is {latent_dim}"
        )

    # A private generator gives the same draw as seeding the global one with 42,
    # without resetting NumPy's global random state for the rest of the notebook.
    rng = np.random.RandomState(42)
    dim1, dim2 = (int(d) for d in rng.randint(0, latent_dim, 2))
    # Both axes landing on the same dimension would collapse the 2-D scan onto a
    # line; redraw the second axis until it differs (impossible when latent_dim == 1).
    while dim2 == dim1 and latent_dim > 1:
        dim2 = int(rng.randint(0, latent_dim))

    x_range, y_range = ranges
    side = 2*n_grid + 1

    # Move n_grid steps in each direction (up/down, left/right) around the centre.
    dx = np.linspace(-x_range, x_range, side)
    dy = np.linspace(-y_range, y_range, side)

    grid = pd.DataFrame(columns=range(side), index=range(side))

    for i_x, i_y in tqdm(product(range(side), range(side)), desc='whole_iters', total=side**2): #each data points
        # Shift the centre along the two chosen coordinate axes only.
        z_point = center.copy()
        z_point[dim1] += dx[i_x]
        z_point[dim2] += dy[i_y]

        grid.iloc[i_x, i_y] = decode_z(z_point, model, n_trials, temp)

    return grid

def decode_z(z, model, n_trials, temp):
    """Decode ``z`` ``n_trials`` times and return the most frequent valid molecule.

    Parameters
    ----------
    z : array-like
        A single latent vector.
    model
        Object exposing ``sample(n_batch=..., z=..., temp=...)``; its output is
        iterated as molecule strings that RDKit can attempt to parse.
    n_trials : int
        Number of copies of ``z`` decoded in one batch.
    temp : float
        Forwarded unchanged to ``model.sample``.

    Returns
    -------
    str or None
        Canonical SMILES of the most frequently decoded valid molecule, or
        ``None`` when no decoded string could be parsed.
    """
    # One row per trial, every row identical to z, in float32.
    z_input = torch.as_tensor(z, dtype=torch.float32).reshape(1, -1).repeat(n_trials, 1)

    decoded_mols = model.sample(n_batch=z_input.shape[0], z=z_input, temp=temp)

    # Parse each string once and canonicalise the parsed molecule directly.
    canon_dec_mols = []
    for smiles in decoded_mols:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            canon_dec_mols.append(Chem.MolToSmiles(mol, canonical=True))

    print(f"ratio of valid molecules : {len(canon_dec_mols)}/{n_trials}")

    # Canonical forms make different spellings of one molecule count together.
    most_freq_mol, _ = find_argmax(Counter(canon_dec_mols))

    return most_freq_mol

def find_argmax(counter):
    """Return ``(element, count)`` for the most common entry of ``counter``.

    ``counter`` is a ``collections.Counter``. An empty counter yields
    ``(None, None)``.
    """
    top = counter.most_common(1)
    # Guard clause: nothing was counted, so there is no winner to report.
    if not top:
        return None, None
    # most_common(1) gives a one-item list holding the (element, count) pair.
    return top[0]


def plot_freq_mols(freq_mols_grid, save_name=None):
    """Draw every molecule of a SMILES grid on a matching grid of subplots.

    Parameters
    ----------
    freq_mols_grid : pandas.DataFrame
        Table of SMILES strings. Cells that are not strings (e.g. ``None`` when
        no valid molecule was decoded) or that RDKit cannot parse are left blank.
    save_name : str or path-like, optional
        When given, the figure is written to this path (its parent directory is
        created if needed) and then closed. When ``None`` the figure stays open.
    """
    import os
    # Import Draw explicitly: the ``Chem.Draw`` attribute only exists once some
    # other import has loaded that submodule.
    from rdkit.Chem import Draw

    n_rows, n_cols = freq_mols_grid.shape

    # squeeze=False keeps ``ax`` two-dimensional even for a single row or column.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(10,10), squeeze=False)

    for i in range(n_rows):
        for j in range(n_cols):
            smiles = freq_mols_grid.iloc[i,j]
            # RDKit raises on non-string input, so only parse real strings.
            mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
            if mol is not None:
                img = Draw.MolToImage(mol, size=(300,300))
                ax[i,j].imshow(img)
            ax[i,j].axis("off")

    if save_name is not None:
        out_dir = os.path.dirname(save_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name)
        # Close rather than clear, so repeated calls do not accumulate open figures.
        plt.close(fig)
    

In [None]:
# Use the first sampled molecule (positional row 0) as the reference point.
test_mol = whole_latent_info.iloc[0]

In [None]:
# Show the selected row: molecule columns plus its mu / logvar / z arrays.
test_mol

In [None]:
# Render the reference molecule from its SMILES string as a 500x500 image.
# NOTE(review): Chem.Draw is never imported explicitly in this notebook; it is
# presumably loaded as a side effect of the PandasTools import - confirm.
img = Chem.Draw.MolToImage(Chem.MolFromSmiles(test_mol['SMILES']), size=(500,500))


# A bare last expression makes the notebook display the image.
img

In [None]:
# (x_range, y_range) scan half-widths to try.
ranges_list = [(1,1), (2,2), (3,3)]

# Sampling temperatures to try.
temp_list = [0.01, 0.1, 1]

# Every (ranges, temp) combination: 3 x 3 = 9 experiments.
experiments = list(product(ranges_list, temp_list))

for i, (ranges, temp) in enumerate(experiments):
    print(f"------------- Experiment {i} -------------")
    # test_mol["mu"][0] takes the first entry of the stored mu array - presumably
    # dropping a leading batch axis so a flat latent vector is passed; confirm.
    grid = sample_latent_space(test_mol["mu"][0], model, ranges=ranges, temp=temp, n_trials=1000, n_grid=6)
    # One image per experiment, written under figures/.
    plot_freq_mols(grid, save_name=f"figures/test_mol_{i}.png")