In [2]:
from datetime import datetime
import tqdm
import torch
from orquestra.qml.models.rbm.th import RBM, TrainingParameters as RBMParams
import sys,os
sys.path.append(os.path.join("../"))
from utils import Experiment, SelfiesEncoding
from models import MolPAT
from pathlib import Path
from orquestra.qml.data_loaders import new_data_loader
from orquestra.qml.api import (
    TorchGenerativeModel,
    GenerativeModel,
    Callback,
    TrainCache,
    convert_to_numpy,
    GenerativeModel,
)

  Chem.MolFromSmarts(x) for x in _mcf.append(_pains, sort=True)["smarts"].values


In [3]:
run_date_time = datetime.today().strftime("%Y_%d_%mT%H_%M")
experiment = Experiment(run_id=f"MolPAT-{run_date_time}")
print(experiment)

Experiment(run_id=MolPAT-2023_23_03T18_22, path=/Users/mohamad/workspace/repos/insilico-drug-discovery/notebooks/experiment_results/MolPAT-2023_23_03T18_22)


In [6]:
# GLOBAL EXPERIMENT VARIABLES
experiment.path_to_dataset = "../data/KRAS_G12D/KRAS_G12D_inhibitors_update2023.csv"
experiment.dataset_id = "insilico_KRAS"

# TRAINING VARIABLES
experiment.n_epochs = 1
experiment.batch_size = 32
experiment.seed = 1000
experiment.set(n_prior_epochs=20)

In [7]:
# initialize the SMILES or SELFIES object
selfies = SelfiesEncoding(
    experiment.path_to_dataset,
    dataset_identifier=experiment.dataset_id
)

In [8]:
# can add function to RBM to return the configuration
prior = RBM(
    n_visible_units=10, 
    n_hidden_units=10, 
    training_parameters=RBMParams()
)

prior_config = dict(
    name=prior.name,
    n_visible_units=prior.n_visible_units,
    n_hidden_units=prior.n_hidden_units,
    training_parameters=prior.training_parameters.__dict__
)

experiment.model_configurations.append(prior_config)

In [9]:
model = MolPAT(
    vocab_size=selfies.num_emd,
    seq_len=selfies.max_length,
    start_token_index=selfies.start_char_index,
    prior_dim=prior.sample_size[-1],
    padding_token_index=selfies.pad_char_index,
    hidden_dim=128,
    embedding_dim=256,
)

experiment.model_configurations.append(model.config.as_dict())

In [10]:
dummy_input_data = torch.zeros((1, model.seq_len)).long()
dummy_prior_samples = prior.generate(1).float()
model.summary(input_data=(dummy_input_data, dummy_prior_samples), col_names=["input_size", "output_size"], depth=4);

Layer (type (var_name):depth-idx)                            Input Shape               Output Shape
_DiscretePATModel                                            --                        --
├─TransformerEncoder (encoder): 1                            --                        --
│    └─ModuleList (layers): 2-1                              --                        --
├─Embedding (embedding): 1-1                                 [1, 835]                  [1, 835, 256]
├─Concatenate (concatenate): 1-2                             [1, 835, 256]             [1, 835, 266]
├─Linear (pre_pe_projection): 1-3                            [1, 835, 266]             [1, 835, 128]
├─PositionalEncoding (positional_encoding): 1-4              [1, 835, 128]             [1, 835, 128]
│    └─Dropout (dropout): 2-2                                [1, 835, 128]             [1, 835, 128]
├─TransformerEncoder (encoder): 1                            --                        --
│    └─ModuleList (layers): 2-3    

In [11]:
sample = prior.generate(5).float()
mols = model.generate(sample)

In [12]:
n_epochs = 10
batch_size = 32
n_test_samples = 20000

In [13]:
encoded_samples_th = torch.tensor(selfies.encoded_samples)
data = encoded_samples_th.long()

In [14]:
epoch_plot_dir = Path("experiment_results") / "epoch_plots" / experiment.run_id
epoch_plot_dir = epoch_plot_dir.resolve()

if epoch_plot_dir.exists() is False:
    os.makedirs(str(epoch_plot_dir))


dataloader = new_data_loader(
    data=data, batch_size = batch_size
).shuffle(12345)
train_cache = TrainCache()

In [None]:
generated_compunds = {}
live_model_loss = []
for epoch in range(1, n_epochs + 1):
    with tqdm.tqdm(total=dataloader.n_batches) as pbar:
        pbar.set_description(f"Epoch {epoch} / {n_epochs}.")
        concat_prior_samples = []
        for batch_idx, batch in enumerate(dataloader):
            prior_samples = prior.generate(batch.batch_size).long()
            batch.targets = prior_samples
            batch_result = model.train_on_batch(batch)
            train_cache.update_history(batch_result)
            concat_prior_samples = concat_prior_samples + prior_samples.tolist()

            pbar.set_postfix(dict(Loss=batch_result["loss"]))
            pbar.update()
            # if torch.cuda.is_available():
            #     torch.cuda.empty_cache()
        th_prior_samples = torch.tensor(concat_prior_samples)
        # put model in evaluation mode such that layers like Dropout, Batchnorm don't affect results
        model.set_eval_state()



        

In [22]:
model.set_train_state()

In [28]:
P0 = prior.generate(1).float()


In [33]:
S0 = model.generate(P0)

In [34]:
selfies.decode_fn(S0)

['C1COC(=C)C2CC1(C(=C)N=C(C2C3C)Br)CN=C3(CC[C@@H](F))']