In [None]:
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import PandasTools
import zipfile
from io import BytesIO

import selfies as sf

import sys
sys.path.append("..")
import moses
from moses.vae import VAE
from moses.vae_property import VAEPROPERTY
from moses.utils import CharVocab, StringDataset, SELFIESVocab
from moses.vae.trainer import VAETrainer
from moses.vae_property.trainer import VAEPROPERTYTrainer 

from moses.metrics import QED, SA, logP
from moses.utils import get_mol

from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# Checkpoint of a VAEPROPERTY training run; edit these to inspect another run/epoch.
folder_path = "../checkpoints/ZINC_vae_property_20240604_052224"
CHECKPOINT_EPOCH = 80      # formatted as a zero-padded 3-digit suffix, e.g. 080
N_TRAIN_SAMPLES = 50_000   # number of leading training rows to encode

# config and vocab are presumably pickled Python objects rather than plain tensors,
# so they need weights_only=False: torch >= 2.6 defaults to weights_only=True and
# refuses non-allowlisted classes (the keyword itself requires torch >= 1.13).
# SECURITY: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
config = torch.load(f'{folder_path}/vae_property_config.pt', weights_only=False)
vocab = torch.load(f'{folder_path}/vae_property_vocab.pt', weights_only=False)
model_path = f'{folder_path}/vae_property_model_{CHECKPOINT_EPOCH:03d}.pt'

# The training config is forwarded to get_dataset (presumably to pick the dataset and
# property columns -- TODO confirm); only the first N_TRAIN_SAMPLES rows are kept.
train_data = moses.get_dataset('train', config)[:N_TRAIN_SAMPLES]

In [None]:
# Rebuild the architecture from the saved vocab/config, then restore the trained weights.
# map_location="cpu" lets a checkpoint written from a GPU run load on a CPU-only machine;
# load_state_dict copies the tensors onto whatever device the model parameters live on.
model = VAEPROPERTY(vocab, config)
model.load_state_dict(torch.load(model_path, map_location="cpu"))

In [None]:
# The trainer is only used for its data pipeline here; no training happens in this notebook.
trainer = VAEPROPERTYTrainer(config)
# shuffle=False so batches presumably follow the row order of `train_data`.
sample_loader = trainer.get_dataloader(
    model,
    train_data,
    shuffle=False,
)

In [None]:
model.eval()

# Encode every molecule yielded by `sample_loader`, keeping the encoder mean `mu` as its
# latent point and batch[1] (the property target yielded alongside each batch) for colouring.
# Batches are concatenated rather than filtered by a hard-coded size, so the final partial
# batch is kept and the loop works for any configured batch size.
z_batches = []
y_batches = []
with torch.no_grad():  # inference only: don't build an autograd graph
    for batch in sample_loader:
        # batch[0] is iterated as a sequence of tensors, each moved to the model's device.
        input_batch = tuple(data.to(model.device) for data in batch[0])
        mu, _, _ = model.forward_encoder(input_batch)
        mu = mu.cpu().numpy()
        z_batches.append(mu)
        # One row of targets per encoded molecule, so partial batches concatenate cleanly.
        y_batches.append(np.asarray(batch[1]).reshape(len(mu), -1))

if not z_batches:
    raise RuntimeError("sample_loader yielded no batches to encode")

# z_list: (n_molecules, latent_dim); y_list: (n_molecules,) for a single property target.
z_list = np.concatenate(z_batches, axis=0).squeeze()
y_list = np.concatenate(y_batches, axis=0).squeeze()

In [None]:
# Sanity check: expected (n_encoded_molecules, latent_dim) unless squeeze dropped a size-1 axis.
np.shape(z_list)

In [None]:
# Project the latent means onto their first two principal components.
# Slower non-linear alternative: z_viz = TSNE(n_components=2).fit_transform(z_list)
viz = PCA(n_components=2)
z_pca = viz.fit_transform(z_list)
print(z_pca.shape)

# Rescale each projected axis independently to [0, 1] for plotting.
z_viz = MinMaxScaler().fit_transform(z_pca)

In [None]:
# 2-D projection of the latent means, coloured by the property target of each molecule.
# np.ravel flattens y_list to one value per point, whether it is stored per-sample or per-batch.
fig, ax = plt.subplots()
points = ax.scatter(
    z_viz[:, 0], z_viz[:, 1],
    c=np.ravel(y_list), cmap='viridis', marker='.', s=10, alpha=0.5, edgecolors='none',
)
fig.colorbar(points, ax=ax, label='property value')
ax.set(
    title='VAE latent means (2-D projection, min-max scaled)',
    xlabel='component 1',
    ylabel='component 2',
)
fig.tight_layout()
plt.show()