In [2]:
import moses
from moses.vae import VAE
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from rdkit import Chem, RDLogger
from pathlib import Path


RDLogger.logger().setLevel(RDLogger.CRITICAL)
RDLogger.DisableLog("rdApp.*")

In [5]:
config_fpath = Path("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/config.txt")
state_dict_fpath = Path("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/chemvae/vae_checkpoint_final.pt")

In [6]:
config = torch.load(config_fpath)
vocab = torch.load(config.vocab_save)
state = torch.load(state_dict_fpath)

In [7]:
model = VAE(vocab, config)
model.load_state_dict(state)
model.to("cuda")
model.eval()

VAE(
  (x_emb): Embedding(50, 50, padding_idx=48)
  (encoder_rnn): GRU(50, 256, batch_first=True)
  (q_mu): Linear(in_features=256, out_features=128, bias=True)
  (q_logvar): Linear(in_features=256, out_features=128, bias=True)
  (decoder_rnn): GRU(178, 512, num_layers=3, batch_first=True)
  (decoder_lat): Linear(in_features=128, out_features=512, bias=True)
  (decoder_fc): Linear(in_features=512, out_features=50, bias=True)
  (encoder): ModuleList(
    (0): GRU(50, 256, batch_first=True)
    (1): Linear(in_features=256, out_features=128, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
  (decoder): ModuleList(
    (0): GRU(178, 512, num_layers=3, batch_first=True)
    (1): Linear(in_features=128, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=50, bias=True)
  )
  (vae): ModuleList(
    (0): Embedding(50, 50, padding_idx=48)
    (1): ModuleList(
      (0): GRU(50, 256, batch_first=True)
      (1): Linear(in_features=256, out_featu

In [8]:
all_smiles = list(pd.read_csv("/storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/lincs_trapnell.smiles")["smiles"].values)

In [9]:
embeddings = []
for s in tqdm(all_smiles):
    with torch.no_grad():
        tensors = [model.string2tensor(s)]
        emb, _ = model.forward_encoder(tensors)
    embeddings.append(emb.cpu().numpy())

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17869/17869 [00:24<00:00, 742.18it/s]


In [10]:
emb = np.concatenate(embeddings, axis=0)
final_df = pd.DataFrame(emb, index=all_smiles, columns=[f"latent_{i+1}" for i in range(emb.shape[1])])
final_df.to_parquet("chemvae.parquet")
final_df

Unnamed: 0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,latent_10,...,latent_119,latent_120,latent_121,latent_122,latent_123,latent_124,latent_125,latent_126,latent_127,latent_128
C[C@H](NC(=O)/C(C#N)=C/c1cccc(Br)n1)c1ccccc1,1.557116,0.449109,-0.212241,0.265395,1.368745,-0.865343,0.090169,-0.367736,0.105361,0.530865,...,0.242221,-0.231116,-2.729977,0.919116,-0.453863,0.219208,0.485712,0.634920,0.439277,0.078147
Cc1cc(Nc2cc(CN3CCOCC3)c3nc(C)c(Cc4ccc(Cl)cc4F)n3n2)[nH]n1,1.067754,0.533023,0.595420,0.754059,0.904944,0.480257,-0.454201,-0.470015,-0.320997,-1.815178,...,-1.151122,-1.349364,0.794603,-1.816965,-0.291386,2.363030,-0.085799,-0.386940,0.036858,-0.809442
Cc1cc(N2CCOCC2)cc2[nH]c(-c3c(NCC(O)c4cccc(Cl)c4)cc[nH]c3=O)nc12,0.607472,1.015737,-0.442386,-1.595294,0.497546,0.812311,1.622832,-1.344401,-0.495128,-0.564234,...,-0.843197,-0.461671,1.076396,-0.859824,0.022334,1.221702,-0.349340,0.164777,-1.027404,0.745297
Cl.Cl.c1ccc([C@@H]2C[C@H]2NC2CCNCC2)cc1,1.902807,0.023736,0.216138,-0.263214,-1.295511,-0.631556,0.723728,1.488993,-2.693938,0.349539,...,-2.564500,-1.536878,0.546583,0.523096,-0.223104,1.028180,-0.352723,-1.239876,0.246064,1.211902
O=C(c1ccc(/C=C/c2n[nH]c3ccccc23)cc1)N1CCNCC1,-2.072833,0.183836,1.759904,0.790609,-0.334443,-1.022856,-1.063119,-0.542503,-0.627536,0.971684,...,0.870260,-0.474700,0.221147,-1.369071,0.712489,-1.004885,0.844010,0.255219,0.302455,0.778227
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CCCC(=O)Nc1ccc2c(c1)C(=O)N(C)C[C@H](OC)[C@@H](C)CN(Cc1ccc(-c3ccccn3)cc1)[C@@H](C)CO2,-0.198073,0.768631,-0.656850,-0.485141,-0.118598,-0.862468,1.554209,-0.019635,-1.222284,0.432182,...,-0.025859,1.828059,0.993634,-0.957795,1.776462,0.822132,-1.658626,-1.435088,0.541937,1.202558
Cc1cc(CS(=O)(=O)c2ccccc2)cc(OCc2ccc(CN3CCC[C@@H]3CO)cc2)c1,-0.332620,-0.690646,-0.650203,-0.319614,0.423077,0.155828,-1.601239,-1.608535,-2.804836,-0.260143,...,0.462508,-0.834634,1.057046,-0.566899,1.631578,1.147480,2.037716,1.406872,-0.171140,-0.583970
CN(C)CCOc1ccc(/C(=C(\CCCl)c2ccccc2)c2ccccc2)cc1,-0.633693,-0.070029,-0.264711,0.803552,1.716896,0.495845,-1.544559,2.549381,-1.420597,0.407462,...,-0.858744,1.012648,-0.314032,0.011283,-0.380853,2.732352,-0.685940,-0.252476,-0.519152,0.563595
CC1(C)C=Cc2c(ccc3c2[N+]([O-])=C2C3=C[C@@]34NC(=O)[C@]5(CCCN5C3=O)C[C@H]4C2(C)C)O1,0.842776,0.310461,-1.037071,-0.107875,0.680730,0.074522,0.323799,-0.014356,-0.188539,-0.145040,...,-1.735851,2.686139,-1.082980,0.094797,1.601195,1.199945,2.200440,1.683513,-0.571433,0.032320


## Bit of testing

Testing sampled SMILES for validitiy

In [11]:
def smiles_is_syntatically_valid(smiles):
    return Chem.MolFromSmiles(smiles, sanitize=False) is not None


def smiles_is_semantically_valid(smiles):
    valid = True
    try:
        Chem.SanitizeMol(Chem.MolFromSmiles(smiles, sanitize=False))
    except:
        valid = False
    return valid

In [12]:
samples = model.sample(1000)

  x[~eos_mask, i] = w[~eos_mask]
  end_pads[i_eos_mask] = i + 1


In [13]:
syn_valid = sum(smiles_is_syntatically_valid(s) for s in samples) / len(samples)
sem_valid = sum(smiles_is_syntatically_valid(s) for s in samples) / len(samples)
print(f"TOTAL: {len(samples)} SYN: {syn_valid} SEM: {sem_valid}")

TOTAL: 1000 SYN: 0.954 SEM: 0.954


In [14]:
samples

['O=C(COC(=O)c1csc2c1CCCC2)[C@H]1C[C@@H](O)[C@H](O)C[C@H]1C(=O)O)C(C)(C)C',
 'CC(C)c1nc(CNC(=O)c2ccc(-n3cccc3)cc2F)cs1',
 'CN(C)C(=O)c1ccc(S(=O)(=O)NC[C@H](c2cccc3ccccc23)[NH+]2CCCC2)cc1',
 'C[C@H]1CC(C(=O)Nc2cc(C(=O)N(C)C)ccc2F)C[C@@H](C)O1',
 'COc1ccc2c3c([nH]c2c1)[C@@H](CO)N(C(=O)c1ccncc1)CC31CN(C)C1',
 'CCOc1ccc([C@@H]2CCCN2C(=O)N[C@H](C)c2cn(C)nc2C)cc1',
 'CC(=O)c1c(C)[nH]c(C(=O)[C@H]2Cc3cc(C)c(C)cc3O2)c1C',
 'CCc1nn(C)cc1NC(=O)C1C[C@H](C)O[C@H](C)C1',
 'C[NH+](C)[C@H](CNC(=O)c1cccc(NC(=O)C2CCCCC2)c1)c1ccccc1',
 'Oc1ccc(-c2csc(C3CCCC3)c2)[nH]1',
 'Cc1cc(C)cc(C2=CCN(C(=O)c3csc(Cc4cccs4)n3)CC2)c1',
 'COc1ccc(N2CCOCC2)cc1COC(C)C',
 'O=C(c1ccn(Cn2cccn2)n1)N1CCC[C@@H](c2nnc3n2CCCC3)C1',
 'C=CCNC(=O)c1csc(Cc2ccc(F)cc2)n1',
 'OCC[C@H]1CN(Cc2ccccc2)CC[NH+]1C[C@H]1CCCO1',
 'CCCN(CC(N)=O)C(=O)C1C(C)(C)C1(C)C',
 'O=C(NCc1ccc(Cl)cc1)c1cccc(S(=O)(=O)N2CCNC2)c1',
 'Cc1ccc(NC(=O)COc2ccc3c(=O)c4ccccc4o3)cc2c1',
 'C#CC[C@H]([NH2+]C1CCC(O)CC1)c1cccs1',
 'CC1(C#N)CS[C@H]1c1ccccc1',
 'CC[C@@H](C)[C@H