In [54]:
from moses.vae import VAE
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from rdkit import Chem, RDLogger
from pathlib import Path


RDLogger.logger().setLevel(RDLogger.CRITICAL)
RDLogger.DisableLog("rdApp.*")

In [7]:
config_fpath = Path("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/config.txt")
state_dict_fpath = Path("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chem_020.pt")

In [9]:
config = torch.load(config_fpath)
vocab = torch.load(config.vocab_save)
state = torch.load(state_dict_fpath)

In [10]:
model = VAE(vocab, config)
model.load_state_dict(state)
model.to("cuda")
model.eval()

device(type='cpu')

In [13]:
all_smiles = list(pd.read_csv("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/lincs_trapnell.smiles")["smiles"].values)

In [34]:
embeddings = []
for s in tqdm(all_smiles):
    with torch.no_grad():
        tensors = [model.string2tensor(s)]
        emb, _ = model.forward_encoder(tensors)
    embeddings.append(emb.cpu().numpy())

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17869/17869 [00:18<00:00, 987.78it/s]


In [42]:
emb = np.concatenate(embeddings, axis=0)
final_df = pd.DataFrame(emb, index=all_smiles, columns=[f"latent_{i+1}" for i in range(emb.shape[1])])
final_df.to_parquet("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/chemvae.parquet")
final_df

Unnamed: 0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,latent_10,...,latent_119,latent_120,latent_121,latent_122,latent_123,latent_124,latent_125,latent_126,latent_127,latent_128
C[C@H](NC(=O)/C(C#N)=C/c1cccc(Br)n1)c1ccccc1,0.354032,-0.252608,-0.854945,1.369055,1.118283,-0.577468,-0.039789,-0.411299,0.642768,1.659209,...,0.857349,0.910225,-0.521672,0.900717,-1.899955,0.805364,0.475567,-0.337396,1.232203,1.270550
Cc1cc(Nc2cc(CN3CCOCC3)c3nc(C)c(Cc4ccc(Cl)cc4F)n3n2)[nH]n1,-1.508617,-0.939709,-0.857973,-1.199011,-1.474937,-0.667800,0.518754,1.530243,-0.614644,1.259345,...,0.791618,-0.381128,-1.024636,1.030371,0.455687,1.184058,-0.774865,0.132197,0.030353,-1.366459
Cc1cc(N2CCOCC2)cc2[nH]c(-c3c(NCC(O)c4cccc(Cl)c4)cc[nH]c3=O)nc12,0.220337,-0.351990,-1.530201,1.758891,-0.209412,-1.576620,0.354227,0.448785,0.667389,1.443004,...,0.310180,-1.186641,0.526746,-1.588492,0.757901,1.278900,0.548940,-0.206284,1.627874,-1.401685
Cl.Cl.c1ccc([C@@H]2C[C@H]2NC2CCNCC2)cc1,-0.115807,-0.963370,-0.654215,2.280445,-1.091906,-1.344712,0.920827,-0.882062,0.610847,-0.019360,...,1.433155,-0.290291,-0.955461,-1.273009,-1.058183,-1.884115,1.376229,-0.751575,1.640977,-0.175396
O=C(c1ccc(/C=C/c2n[nH]c3ccccc23)cc1)N1CCNCC1,-0.022574,-1.165367,1.746031,0.846614,0.541168,0.276520,-0.274634,1.192463,0.791474,-2.141095,...,0.305412,-0.326080,-1.034308,2.027675,0.662548,0.614320,-0.974802,-0.080513,0.421544,-0.185845
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CCCC(=O)Nc1ccc2c(c1)C(=O)N(C)C[C@H](OC)[C@@H](C)CN(Cc1ccc(-c3ccccn3)cc1)[C@@H](C)CO2,0.058554,1.500882,-0.613799,0.929431,0.627461,-0.593815,0.956860,0.629603,-0.197759,0.011308,...,-1.757150,-1.049584,0.193621,-0.528292,1.204237,0.906834,2.108345,-0.511339,-0.092439,-0.984728
Cc1cc(CS(=O)(=O)c2ccccc2)cc(OCc2ccc(CN3CCC[C@@H]3CO)cc2)c1,0.895613,1.568560,0.635703,-0.366266,-0.298033,-0.981960,-0.019534,1.522135,-0.888550,-0.963127,...,-0.649305,-0.127500,-0.475812,-0.293091,0.184421,-0.273564,-1.854274,-0.506698,1.804127,0.101668
CN(C)CCOc1ccc(/C(=C(\CCCl)c2ccccc2)c2ccccc2)cc1,0.077112,-0.827046,0.272234,-0.483808,-1.268531,0.746058,0.149153,-1.148996,-1.562799,0.234854,...,-0.508319,0.009671,-0.991139,-1.428197,0.653613,-1.160330,0.758727,-0.153587,-1.354894,-0.706485
CC1(C)C=Cc2c(ccc3c2[N+]([O-])=C2C3=C[C@@]34NC(=O)[C@]5(CCCN5C3=O)C[C@H]4C2(C)C)O1,-0.878610,0.454518,0.256789,0.384832,-1.201396,-0.744467,0.778796,-0.436974,-0.311831,1.828151,...,0.547017,-0.144260,0.537531,-1.722388,-0.864747,-0.713463,-0.097453,-0.383025,1.068760,-1.283774


## Bit of testing

Testing sampled SMILES for validitiy

In [47]:
def smiles_is_syntatically_valid(smiles):
    return Chem.MolFromSmiles(smiles, sanitize=False) is not None


def smiles_is_semantically_valid(smiles):
    valid = True
    try:
        Chem.SanitizeMol(Chem.MolFromSmiles(smiles, sanitize=False))
    except:
        valid = False
    return valid

In [52]:
samples = model.sample(1000)

In [57]:
syn_valid = sum(smiles_is_syntatically_valid(s) for s in samples) / len(samples)
sem_valid = sum(smiles_is_syntatically_valid(s) for s in samples) / len(samples)
print(f"TOTAL: {len(samples)} SYN: {syn_valid} SEM: {sem_valid}")

TOTAL: 1000 SYN: 0.871 SEM: 0.871
