## Create a Memmap dataset

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to local packages are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [22]:
import pandas as pd
import vae_cyc as vc
import numpy as np
import torch
import pytorch_lightning as pl
from tqdm import tqdm 
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger

In [3]:
# Tokenisation settings used by the cells below: `max_len` is the fixed row
# width of the token matrix built in process_df, `vocab` maps characters to ids.
max_len = 128
smiles_col = 'smiles'
vocab = vc.Vocab.load('zinc-vocab.json')

# Read the training table and drop every row that contains a missing value.
df = pd.read_parquet('/home/jovyan/data/random_subset_1.parquet').dropna()

In [4]:
# Table that is later tokenised and saved as the 'val' split.
# NOTE(review): unlike `df`, no dropna() is applied here — confirm the CSV has no missing SMILES.
testset = pd.read_csv('/home/jovyan/data/testset.csv')

In [5]:
def process_df(df, smiles_col):
    """Tokenise a column of SMILES strings into a fixed-width int8 matrix.

    Each string is passed through ``vocab.encode_special``, mapped character by
    character through ``vocab.char2idx``, wrapped in SOS/EOS ids and right-padded
    with ``vocab.pad_idx`` up to ``max_len`` columns.

    Relies on the notebook-level globals ``vocab`` and ``max_len``.

    Args:
        df: DataFrame holding the SMILES strings.
        smiles_col: Name of the column in ``df`` that contains the SMILES.

    Returns:
        np.ndarray of shape ``(len(df), max_len)`` and dtype ``int8``.

    Raises:
        ValueError: If a row needs more than ``max_len`` tokens (SOS and EOS
            included). The row index and length are reported so the offending
            molecule can be filtered out before re-running.
    """
    smiles = df[smiles_col].values
    # Start from an all-padding matrix so only the real tokens are written per row.
    all_tokens = np.full((len(smiles), max_len), vocab.pad_idx, dtype=np.int8)
    for idx, smi in enumerate(tqdm(smiles)):
        encoded = vocab.encode_special(smi)
        tokens = [vocab.sos_idx] + [vocab.char2idx[ch] for ch in encoded] + [vocab.eos_idx]
        if len(tokens) > max_len:
            # Fail with a precise message instead of NumPy's generic broadcast error.
            raise ValueError(
                f'row {idx}: SMILES encodes to {len(tokens)} tokens '
                f'(including SOS/EOS), which exceeds max_len={max_len}'
            )
        all_tokens[idx, :len(tokens)] = np.array(tokens, dtype=np.int8)
    return all_tokens

In [14]:
def save_memmap(vectors, fname, split, dtype='int8'):
    """Write an array to ``{fname}_{split}.dat`` plus a JSON metadata sidecar.

    The sidecar ``{fname}_{split}.json`` stores the array shape and the dtype
    the data file was written with, so the raw file can be re-opened later.

    Args:
        vectors: Array-like exposing ``.shape`` whose contents are written out.
        fname: Path prefix shared by the ``.dat`` and ``.json`` files.
        split: Split name appended to the prefix (e.g. ``'train'``, ``'val'``).
        dtype: On-disk dtype; anything accepted by ``np.dtype``. Values are cast
            to this dtype on assignment.
    """
    import json  # local import: json is not part of the notebook's import cell

    # Record the dtype actually used for the data file (not a fixed 'int8'),
    # normalised to its string name so dtype objects are JSON-serialisable too.
    meta = {'shape': vectors.shape, 'dtype': np.dtype(dtype).name}
    with open(f'{fname}_{split}.json', 'w') as f:
        json.dump(meta, f)

    data = np.memmap(f'{fname}_{split}.dat', dtype=dtype, mode='w+', shape=vectors.shape)
    data[:] = vectors
    # Push the written pages to disk before the memmap goes out of scope.
    data.flush()

In [18]:
# Number of rows in the training frame (after the dropna() applied at load time).
len(df)

99999976

In [7]:
# Tokenise the full training frame and write it to random_subset_1_train.{dat,json}.
# The tokenised array is passed straight through so no extra reference keeps it alive.
save_memmap(
    vectors=process_df(df=df, smiles_col='smiles'),
    fname='random_subset_1',
    split='train',
)

100%|██████████| 99999976/99999976 [19:00<00:00, 87695.23it/s]


In [9]:
# Tokenise the held-out frame and write it to random_subset_1_val.{dat,json}.
save_memmap(
    vectors=process_df(df=testset, smiles_col='smiles'),
    fname='random_subset_1',
    split='val',
)

100%|██████████| 9736/9736 [00:00<00:00, 53819.05it/s]


In [None]:
# Build the train/val datasets from the path prefixes produced by save_memmap
# ('random_subset_1' + '_train' / '_val').
# NOTE(review): presumably TransformerMemapDataset opens '<prefix>.json' and
# '<prefix>.dat' itself — confirm against vae_cyc.
train_ds = vc.TransformerMemapDataset('random_subset_1_train', vocab=vocab)

val_ds = vc.TransformerMemapDataset('random_subset_1_val', vocab=vocab)

In [None]:
# Batch both splits with the collate function of the dataset being loaded.
# NOTE(review): assumes TransformerMemapDataset instances expose `collate` — confirm in vae_cyc.
train_dl = DataLoader(train_ds, batch_size=128, collate_fn=train_ds.collate, num_workers=45)
val_dl = DataLoader(val_ds, batch_size=128, collate_fn=val_ds.collate, num_workers=45)

# Send metrics to the "run_pod" Weights & Biases project; log_model=True also
# asks the logger to upload model checkpoints.
wandb_logger = WandbLogger(project="run_pod", log_model=True, name='transformer-memmap-moses')

# Run on device index 2 with 16-bit precision.
trainer = pl.Trainer(devices=[2], logger=wandb_logger, precision=16)
# NOTE(review): max_seq_length=129 while the token rows are padded to max_len=128 —
# confirm the extra position is intentional.
model = vc.Transformer(vocab_size=len(vocab.char2idx), 
                       num_heads=10, 
                       hidden_dim=128, 
                       num_layers=10, 
                       embed_size=200, 
                       latent_dim=64, 
                       vocab=vocab, 
                       max_seq_length=129, lr=0.0001)

# Train with the training loader and validate with the validation loader.
trainer.fit(model, train_dl, val_dl)

# Put the trained model into evaluation mode for subsequent use.
model = model.eval()