In [2]:
import torch
# Step up one directory; presumably so the relative `data/...` paths used in
# later cells resolve from the project directory — TODO confirm.
# NOTE(review): this is a relative move, so re-running the cell climbs one more
# level each time.
%cd ..

/home/adam/Projects/jointformer


In [160]:
import random

from jointformer.utils.datasets.smiles.base import SmilesDataset
from jointformer.utils.datasets.smiles.guacamol import GuacamolSmilesDataset
from jointformer.utils.transforms.smiles.enumerate import SmilesEnumerator
from jointformer.utils.tokenizers.smiles.smiles import SmilesTokenizer
from jointformer.models.base import Transformer

import torchvision.transforms as transforms

# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell execution.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [193]:
# Data and model setup // follows PyTorch API :)

# --- Tunables ----------------------------------------------------------------
# NOTE(review): PATH_TO_DATA, NUM_SAMPLES and VALIDATE are assigned here but are
# not read by any cell visible in this notebook.
PATH_TO_DATA = 'data/guacamol/test/smiles.txt'
PATH_TO_VOCAB = 'data/vocabularies/deepchem.txt'
NUM_SAMPLES = 1000
MAX_LENGTH = 128
VALIDATE = False

# --- Transform pipeline ------------------------------------------------------
# NOTE(review): `transform` is built but not passed to the dataset or used in
# any visible cell — confirm whether the dataset was meant to receive it.
transform = transforms.Compose([SmilesEnumerator()])

# --- Dataset and tokenizer ---------------------------------------------------
dataset = GuacamolSmilesDataset(split='all')
tokenizer = SmilesTokenizer(
    path_to_vocabulary=PATH_TO_VOCAB,
    max_molecule_length=MAX_LENGTH,
)

# --- Model -------------------------------------------------------------------
# Vocabulary size and maximum sequence length are taken from the tokenizer so
# the two objects stay in agreement. The saved cell output reports 0.04M
# parameters for this configuration.
model = Transformer(
    vocab_size=len(tokenizer),
    max_seq_len=tokenizer.max_molecule_length,
    embedding_dim=32,
    num_heads=8,
    num_layers=2,
    dropout=0.2,
    bias=False,
)

number of parameters: 0.04M


In [194]:
# Smoke test: draw a small random batch of SMILES, tokenize it for the 'lm'
# task, and run a single forward pass to inspect the input/output keys.
bs = 2

# Sample straight from the index range. random.sample accepts a range object,
# so there is no need to materialise a list with one entry per dataset item
# just to pick `bs` of them (the draw is the same for a given RNG state).
idx = random.sample(range(len(dataset)), bs)
smiles = [dataset[i] for i in idx]  # custom DL :')
inputs = tokenizer(smiles, task='lm')  # 'lm', 'mlm', 'predict', 'ae'
print(f"Input keys: {inputs.keys()}")

# Forward pass only; gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs, task='lm')
print(f"Output keys: {outputs.keys()}")

Input keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
Output keys: dict_keys(['embeddings', 'attention_probabilities'])


In [195]:
# Task 1, GuacaMol and general properties like QED







In [196]:
# Display the current value of `tokens`.
# NOTE(review): within the visible notebook `tokens` is only assigned inside the
# vocabulary loop in a later cell, so on a fresh top-to-bottom run this cell
# would raise NameError — confirm and move or remove.
tokens

{'(', ')', '1', '=', 'C', 'N', 'O', 'c'}

In [206]:
# Collect every distinct token that the tokenizer's basic splitter produces
# over the whole dataset, printing each SMILES string that contains the
# token 'p' along the way.
vocabulary = set()

# Iterate the dataset directly: the enumerate() counter was never read, and
# binding it to `idx` overwrote the notebook-level `idx` list of sampled
# indices with an int.
for x in dataset:
    tokens = set(tokenizer.basic_tokenizer._split_into_tokens(x))
    # In-place union; avoids allocating a fresh set on every iteration.
    vocabulary |= tokens
    if 'p' in tokens:
        print(x)

CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1ccc(O)cc1)C(=O)Nc1ccc(C)cc1-2
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1cccc(O)c1)C(=O)Nc1ccc(C)cc1-2
Cc1cccc(CP(Cc2cccc(C)c2)CC2(Cp3c4ccccc4c4ccccc43)COC2)c1
CCOC(=O)C1(C(=O)OCC)C(Cl)C(=O)N1N(c1c(O)ccc2c(P(Cl)Cl)pc(-c3ccccc3)n12)[N+](=O)[O-]
CCOC(=O)C1(C(=O)CBr)C(Cl)C(=O)N1N(c1c(O)ccc2c(P(Cl)Cl)pc(C(=O)O)n12)[N+](=O)[O-]
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1ccc([N+](=O)[O-])cc1)C(=O)Nc1ccc(C)cc1-2
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1cccc(C(=O)O)c1)C(=O)Nc1ccc(C)cc1-2
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1ccc(C)cc1)C(=O)Nc1ccc(C)cc1-2
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1cccc(Cl)c1)C(=O)Nc1ccc(C)cc1-2
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1ccc(CC)cc1)C(=O)Nc1ccc(C)cc1-2
COC(=O)C1(C(C)=O)C(Cl)C(=O)N1N(c1c(O)ccc2c(P(Cl)Cl)pc(-c3ccccc3)n12)[N+](=O)[O-]
CC(=O)C1(C(C)=O)C(Cl)C(=O)N1N(c1c(O)ccc2c(P(Cl)Cl)pc(-c3ccccc3)n12)[N+](=O)[O-]
CCOC(=O)C1(C(=O)OC)C(Cl)C(=O)N1N(c1c(O)ccc2c(P(Cl)Cl)pc(C(=O)O)n12)[N+](=O)[O-]
CCOC(=O)c1pc(P(Cl)Cl)c2n1C(N=Nc1cccc(CC)c1)C(=O)Nc1ccc(C)cc1-2
CCOC(=O)c1pc(P(Cl)Cl)c

KeyboardInterrupt: 

In [205]:
# Display the set of tokens accumulated in `vocabulary`.
# NOTE(review): the saved output belongs to execution 205, i.e. it was produced
# before the last (interrupted, execution 206) run of the loop cell that
# rebuilds `vocabulary`.
vocabulary

{'%10',
 '%11',
 '%12',
 '(',
 ')',
 '-',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '=',
 'B',
 'Br',
 'C',
 'Cl',
 'F',
 'I',
 'N',
 'O',
 'P',
 'S',
 '[B-]',
 '[BH-]',
 '[BH2-]',
 '[BH3-]',
 '[B]',
 '[Br+2]',
 '[Br-]',
 '[C+]',
 '[C-]',
 '[CH+]',
 '[CH-]',
 '[CH2+]',
 '[CH2]',
 '[CH]',
 '[Cl+2]',
 '[Cl+3]',
 '[Cl+]',
 '[Cl-]',
 '[F+]',
 '[F-]',
 '[H]',
 '[I+2]',
 '[I+3]',
 '[I+]',
 '[I-]',
 '[IH2]',
 '[N+]',
 '[N-]',
 '[NH+]',
 '[NH-]',
 '[NH2+]',
 '[NH3+]',
 '[N]',
 '[O+]',
 '[O-]',
 '[OH+]',
 '[O]',
 '[P+]',
 '[P-]',
 '[PH2+]',
 '[PH]',
 '[S+]',
 '[S-]',
 '[SH+]',
 '[SH-]',
 '[SH]',
 '[Se+]',
 '[Se-]',
 '[SeH2]',
 '[SeH]',
 '[Se]',
 '[Si-]',
 '[SiH-]',
 '[SiH2]',
 '[SiH]',
 '[Si]',
 '[b-]',
 '[c+]',
 '[c-]',
 '[cH+]',
 '[cH-]',
 '[n+]',
 '[n-]',
 '[nH+]',
 '[nH]',
 '[o+]',
 '[s+]',
 '[se+]',
 '[se]',
 'b',
 'c',
 'n',
 'o',
 'p',
 's'}