# Converting SMILES sequences to tokens, and tokens to one-hot encodings



In [None]:
import os

import pandas as pd
import numpy as np

import torch

from jak.transformer import XFormer, xformer_train

from jak.data import make_token_dict, \
        sequence_to_vectors, \
        one_hot_to_sequence, \
        tokens_to_one_hot
 


In [None]:
def split_by_smiles(df, test_percentage = 0.1, my_seed=13):
    """
    Split a pandas dataframe into train/validation and test parts by molecule.

    Rows are grouped by their value in the 'SMILES' column, so every row of a
    given molecule lands in the same split, _i.e._ there is no overlap of
    molecules between the training and test datasets.

    Args:
        df: dataframe with a 'SMILES' column.
        test_percentage: probability (0-1) that each unique SMILES string is
            assigned to the test set.
        my_seed: seed for the random assignment; the split is reproducible
            for a fixed seed.

    Returns:
        (train_df, test_df): row subsets of ``df`` (original index kept).
    """
    # A private generator keeps the split reproducible without resetting the
    # global NumPy RNG, and does not depend on `npr` being imported elsewhere.
    # RandomState(seed).rand draws the same numbers as np.random.seed(seed)
    # followed by np.random.rand, so the split itself is unchanged.
    rng = np.random.RandomState(my_seed)

    unique_smiles = df["SMILES"].unique()
    is_test_smiles = rng.rand(*unique_smiles.shape) <= test_percentage
    test_smiles = unique_smiles[is_test_smiles]

    # one boolean mask; its complement is the train/validation selection
    test_mask = df["SMILES"].isin(test_smiles).to_numpy()

    test_df = df.loc[test_mask]
    train_df = df.loc[~test_mask]

    # guard against a test-data leak: no molecule may appear in both splits
    shared_smiles = set(train_df["SMILES"]) & set(test_df["SMILES"])
    assert not shared_smiles, \
        f"something went wrong, {len(shared_smiles)} shared molecules (test data leak)"

    print(train_df.head())
    return train_df, test_df

In [None]:
# Character vocabulary for tokenizing SMILES strings. At this point in the
# notebook, make_token_dict is the one imported from jak.data above.
smiles_vocab = "#()+-1234567=BCFHINOPS[]cilnors"
token_dict = make_token_dict(smiles_vocab)
# Load the JAK training data (path is relative to the notebook's working
# directory); later cells read its "SMILES" column.
df = pd.read_csv("../data/train_JAK.csv")
df.head()

In [None]:
# note, this code works as part of the library at github.com/Cogibra/Sepia
# running this cell alone won't replicate the functionality I used to pre-process the dataset
# unless you first install Sepia

####
# with your virtualenv/conda env/ your favorite env manager env activated
#
# git clone git@github.com:Cogibra/Sepia.git
# cd Sepia
# pip install -e .
# you'll also need jax. If you want jax to use CUDA, installation can be tricky
# (the CUDA and jax versions must match); check out github.com/google/jax for install instructions
####

import argparse
import os
import copy

import jax
from jax import numpy as jnp
from jax import grad

import numpy as np
import numpy.random as npr

from collections import namedtuple

from sepia.common import query_kwargs
import sepia.optimizer as optimizer

from sepia.seq.data import \
        aa_keys, \
        make_sequence_dict, \
        make_token_dict, \
        tokens_to_one_hot, \
        compose_batch_tokens_to_one_hot, \
        one_hot_to_sequence, \
        vectors_to_sequence, \
        sequence_to_vectors,\
        batch_sequence_to_vectors

# parameters (namedtuples)
"""
from sepia.seq.functional import \
        NICEParametersWB, \
        NICEParametersW, \
        SelfAttentionWB, \
        SelfAttentionW, \
        EncodedAttentionW, \
        EncoderParams, \
        DecoderParams, \
        make_layers_tuple, \
        MLPParams 

# functions
from sepia.seq.functional import \
        encoder, \
        decoder, \
        bijective_forward, \
        bijective_reverse
"""
from sepia.seq.data import \
        make_sequence_dict, \
        vectors_to_sequence, \
        sequence_to_vectors


class SeqDataLoader():
    """
    Holds string sequences (e.g. SMILES) as batched one-hot arrays.

    After ``setup_dataset`` / ``set_dataset`` the data lives in
    ``self.dataset`` with shape
    ``(num_batches, batch_size, seq_length, token_dim)``. When the number of
    sequences is not a multiple of ``batch_size``, rows are repeated from the
    start of the dataset (cycling) until it is.

    Args:
        token_dict: token mapping passed to ``batch_sequence_to_vectors``.
        seq_length: sequence length; passed as ``pad_to`` when tokenizing.
        token_dim: size of the one-hot (class) axis.
        **kwargs: optional ``shuffle`` (default False; NOTE: not implemented
            yet, the flag is currently ignored), ``batch_size`` (default 8),
            ``seed`` (default 13; stored as ``my_seed`` but currently unused)
            and ``dataset`` (forwarded to ``setup_dataset``).
    """

    def __init__(self, token_dict: dict, seq_length: int, token_dim: int,
                 **kwargs: dict):

        self.token_dict = token_dict
        self.seq_length = seq_length
        self.token_dim = token_dim

        self.shuffle = query_kwargs("shuffle", False, **kwargs)
        self.batch_size = query_kwargs("batch_size", 8, **kwargs)
        self.my_seed = query_kwargs("seed", 13, **kwargs)

        if "dataset" in kwargs:
            self.setup_dataset(kwargs["dataset"])

    @staticmethod
    def _pad_to_batch_multiple(dataset, batch_size: int):
        """
        Return ``dataset`` with rows (axis 0), cycled from its start, appended
        until its length is a multiple of ``batch_size``.

        Works for numpy and jax arrays; an empty or already-aligned dataset is
        returned unchanged.
        """
        num_rows = dataset.shape[0]
        remainder = num_rows % batch_size
        if num_rows == 0 or remainder == 0:
            return dataset

        padded_rows = num_rows + batch_size - remainder
        # row i of the result is row (i mod num_rows) of the input, which also
        # covers datasets that are smaller than a single batch
        return dataset[np.arange(padded_rows) % num_rows]

    def setup_dataset(self, dataset: np.ndarray):
        """
        Tokenize and one-hot encode ``dataset`` and store it in ``self.dataset``.

        Args:
            dataset: 1D np.ndarray (or list) of string sequences.
        """
        if isinstance(dataset, list):
            dataset = np.array(dataset)

        # NOTE: `self.shuffle` is accepted but not implemented; order is kept.

        dataset = self._pad_to_batch_multiple(dataset, self.batch_size)
        dataset = dataset.reshape(-1)

        token_dataset = batch_sequence_to_vectors(dataset, self.token_dict,
                pad_to=self.seq_length)

        batch_to_one_hot = compose_batch_tokens_to_one_hot(
                pad_to=self.seq_length, pad_classes_to=self.token_dim)
        one_hot_dataset = batch_to_one_hot(token_dataset)

        self.dataset = one_hot_dataset.reshape(-1, self.batch_size,
                self.seq_length, self.token_dim)

    def set_dataset(self, dataset: jnp.ndarray):
        """
        Store an already one-hot encoded dataset, padding and batching it.

        Args:
            dataset: array whose last two axes are (seq_length, token_dim);
                either flat ``(num_seqs, seq_length, token_dim)`` (as written
                by ``save_dataset`` with reshape=True) or already batched.
        """
        assert self.seq_length == dataset.shape[-2], f"seq_length {self.seq_length} != {dataset.shape[-2]}"
        assert self.token_dim == dataset.shape[-1], f"token_dim {self.token_dim} != {dataset.shape[-1]}"

        # flatten to one sequence per row so padding counts sequences, not
        # positions along the seq_length axis
        dataset = dataset.reshape(-1, *dataset.shape[-2:])
        dataset = self._pad_to_batch_multiple(dataset, self.batch_size)

        self.dataset = dataset.reshape(-1, self.batch_size,
                self.seq_length, self.token_dim)

    def save_dataset(self, filepath: str=None, reshape: bool=True):
        """
        Save ``self.dataset`` with ``jnp.save``.

        Args:
            filepath: target .npy path (default "data/temp.npy").
            reshape: if True, flatten the batch axes so the file holds
                ``(num_seqs, seq_length, token_dim)``.
        """
        if filepath is None:
            filepath = os.path.join("data", "temp.npy")

        if reshape:
            save_dataset = self.dataset.reshape(-1, *self.dataset.shape[-2:])
        else:
            save_dataset = self.dataset

        jnp.save(filepath, save_dataset)

    def load_dataset(self, filepath: str=None):
        """
        Load a .npy file into the loader via ``set_dataset``.

        Prints a warning (and leaves the dataset untouched) if the file does
        not exist. Default path is "data/temp.npy".
        """
        if filepath is None:
            filepath = os.path.join("data", "temp.npy")

        if os.path.exists(filepath):
            self.set_dataset(jnp.load(filepath))
        else:
            print(f"warning, {filepath} does not exist")

    def __len__(self) -> int:
        """Number of batches."""
        return len(self.dataset)

    def __getitem__(self, index) -> jnp.ndarray:
        """
        Return batch ``index`` as a slice of shape
        ``(1, batch_size, seq_length, token_dim)``.

        Negative indices count from the end; out-of-range indices raise
        IndexError (the raw slice ``dataset[-1:0]`` or a past-the-end slice
        would otherwise silently be empty).
        """
        num_batches = len(self.dataset)
        if index < 0:
            index += num_batches
        if not 0 <= index < num_batches:
            raise IndexError(f"batch index out of range for {num_batches} batches")

        return self.dataset[index:index+1]

    def __iter__(self):
        """
        Iterate over batches of shape ``(batch_size, seq_length, token_dim)``.

        NOTE: `self.shuffle` is not implemented; batches come in stored order.
        """
        return iter(self.dataset)

In [None]:
# Plain Python list of SMILES strings (SeqDataLoader.setup_dataset converts
# lists to np.array).
dataset = list(df["SMILES"])

In [None]:
# Rebuild token_dict. After the Sepia cell, make_token_dict refers to
# sepia.seq.data.make_token_dict (it shadows the jak.data import), so this
# token_dict may differ from the first one. NOTE(review): confirm which
# implementation is intended.
smiles_vocab = "#()+-1234567=BCFHINOPS[]cilnors"
token_dict = make_token_dict(smiles_vocab)

In [None]:
# Re-run of the earlier data cell: rebuilds token_dict and re-reads the CSV.
smiles_vocab = "#()+-1234567=BCFHINOPS[]cilnors"
token_dict = make_token_dict(smiles_vocab)
df = pd.read_csv("../data/train_JAK.csv")
df.head()

In [None]:
# Rebuild the SMILES list from the freshly loaded dataframe.
dataset = list(df["SMILES"])

In [None]:
# batch_size=1, so the dataset never needs batch padding.
# token_dim=36 is larger than the 31-character vocabulary; presumably the
# extra classes cover tokens added by make_token_dict (e.g. padding) -- TODO
# confirm. Sequences are padded to seq_length=128; how longer SMILES are
# handled depends on batch_sequence_to_vectors (not visible here).
sdl = SeqDataLoader(token_dict=token_dict, seq_length=128, token_dim=36, batch_size=1)

In [None]:
# Tokenize + one-hot encode; sdl.dataset becomes
# (num_batches, 1, 128, 36) per the reshape in setup_dataset.
sdl.setup_dataset(dataset)

In [None]:
# Save with the default reshape=True, i.e. flattened to
# (num_sequences, 128, 36), under ../data/.
dataset_filename = "jak_one_hot_smiles.npy"
sdl.save_dataset(os.path.join("..", "data", dataset_filename))