In [1]:
import os
from functools import partial

from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import SchedulerType

from data import generate_and_scale_mol_descriptors
from molT import (
    DataCollatorForMaskedMolecularModeling,
    MolTConfig,
    MolTForMaskedMM,
    MolTTokenizer,
)

from utils import download_model_from_wandb

import numpy as np
from molT.utils import TokenType
from rdkit import Chem

In [2]:
def tokenize(entry, tokenizer):
    """Tokenize one dataset row with ``tokenizer``.

    Args:
        entry: Row data convertible with ``dict(...)``. It must contain a
            ``"smiles"`` item; every other item is forwarded to the tokenizer
            as a keyword argument.
        tokenizer: Callable invoked as ``tokenizer(smiles, **options)``.

    Returns:
        Whatever ``tokenizer`` returns for this row.

    Raises:
        KeyError: If ``entry`` has no ``"smiles"`` item.
    """
    # Work on a shallow copy so the caller's row is left untouched.
    extra_kwargs = dict(entry)
    smiles_string = extra_kwargs.pop("smiles")

    # Options that are always passed, regardless of the row contents.
    fixed_options = {
        "truncation": False,
        "return_attention_mask": True,
        "return_special_tokens_mask": True,
    }
    return tokenizer(smiles_string, **fixed_options, **extra_kwargs)

In [3]:
# Model configuration: both masking probabilities are 0.15, molecular-descriptor
# tokens are enabled and the target token is disabled.
model_config = MolTConfig(
    atom_bond_mask_probability=0.15,
    molecule_feature_mask_probability=0.15,
    use_mol_descriptor_tokens=True,
    use_target_token=False,
)

# The tokenizer is built from the same config; `tok_func` binds it so the
# function can be handed to `Dataset.map` as a single-argument callable.
tokenizer = MolTTokenizer(model_config)
tok_func = partial(tokenize, tokenizer=tokenizer)

# Keep the first 100 rows of the ZINC validation split, then split them into
# train/test with a fixed seed.
zinc_validation = load_dataset("sagawa/ZINC-canonicalized")["validation"]
zinc_subset = zinc_validation.select(range(100))
ds = zinc_subset.train_test_split(seed=42)

# The helper returns a pair; only the first element (the processed dataset) is kept.
# NOTE(review): the discarded second element is presumably the fitted scaler - confirm.
ds, _ = generate_and_scale_mol_descriptors(
    ds,
    model_config.mol_descriptors,
    num_samples=50,
    num_proc=None,
)

In [4]:
# Tokenize every row of the dataset with the tokenizer bound in `tok_func`.
ds = ds.map(tok_func, num_proc=None)

# Drop every column that is not one of the tokenizer's model inputs.
# `remove_columns` is documented to take a str or a list of str, so pass a
# sorted list instead of a raw set: this matches the documented signature and
# makes the removal order deterministic across runs.
unused_columns = sorted(
    set(ds["train"].column_names) - set(tokenizer.model_input_names)
)
ds = ds.remove_columns(unused_columns)

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [7]:
from torch.utils.data import DataLoader
from molT.collator import DataCollatorForMaskedMolecularModeling

# Collator built from the tokenizer and model config; it turns a list of rows
# into one batch for the loader.
masked_mm_collator = DataCollatorForMaskedMolecularModeling(tokenizer, model_config)

# Batches of 8 rows drawn from the training split.
dl = DataLoader(
    ds["train"],
    batch_size=8,
    collate_fn=masked_mm_collator,
)

In [8]:
# Pull the first collated batch from the loader and show it as the cell output.
first_batch = next(iter(dl))
first_batch

{'token_ids': tensor([[ 0,  0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,
          8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17,
         17, 18, 18, 19, 19, 20, 21,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,
          8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17,
         17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 26, 27,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,
          8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17,
         17, 18, 18,