In [1]:
from functools import partial

from datasets import load_dataset
from transformers import Trainer, TrainingArguments, set_seed

from molT import MolTConfig, MolTTokenizer, MolTForMaskedMM, DataCollatorForMaskedMolecularModeling

In [2]:
# Scratch (disabled): compute RDKit molecular descriptors for a sample SMILES string.
# import rdkit.Chem as Chem
# smiles = 'CC'
# from rdkit.Chem.Descriptors import CalcMolDescriptors
# mol = Chem.MolFromSmiles(smiles)
# CalcMolDescriptors(mol)

In [3]:
# Seed the Python, NumPy and PyTorch RNGs *before* the model is constructed so
# that its initial weights are the same after every kernel restart. The Trainer
# also seeds these RNGs, but only when it is instantiated - by then the model
# built in this cell already exists.
SEED = 42  # same value as the `data_seed` given to TrainingArguments
set_seed(SEED)

# A default MolT configuration drives both the tokenizer and the model.
model_config = MolTConfig()
tokenizer = MolTTokenizer(model_config)
# Built from the config only; no checkpoint is loaded anywhere in this notebook,
# so the weights are presumably randomly initialised (confirm in molT).
model = MolTForMaskedMM(model_config)

In [4]:
def tokenize(entry, tokenizer):
    """Tokenize the SMILES string stored in a single dataset record.

    Args:
        entry: Mapping that holds the molecule under the ``'smiles'`` key.
        tokenizer: Callable invoked as ``tokenizer(smiles, **options)``.

    Returns:
        Whatever ``tokenizer`` returns. The call disables truncation and asks
        for both the attention mask and the special-tokens mask.
    """
    smiles = entry['smiles']
    encoding = tokenizer(
        smiles,
        truncation=False,
        return_attention_mask=True,
        return_special_tokens_mask=True,
    )
    return encoding

def load_data(tokenizer, n_examples=5000, test_size=0.25, seed=42, num_proc=8):
    """Load a subset of ZINC, split it into train/test and tokenize both parts.

    The first ``n_examples`` rows of the ``validation`` split of
    ``sagawa/ZINC-canonicalized`` are taken, shuffled into a train/test split
    and then passed through ``tokenize``.

    Args:
        tokenizer: Tokenizer forwarded to ``tokenize`` for every record.
        n_examples: Number of leading rows to keep; ``None`` keeps the whole
            split. Values larger than the split are clamped to its length.
        test_size: Fraction (or absolute count) of rows placed in the test
            split. 0.25 is the value ``train_test_split`` uses when nothing
            is specified.
        seed: Seed for the shuffle performed by ``train_test_split``. Without
            it a different split is drawn on every run.
        num_proc: Number of worker processes used by ``map``.

    Returns:
        The object returned by ``map`` on the split dataset, with ``'train'``
        and ``'test'`` entries carrying the tokenizer outputs as extra columns.
    """
    subset = load_dataset("sagawa/ZINC-canonicalized")['validation']
    if n_examples is not None:
        # Clamp so that asking for more rows than exist does not raise.
        subset = subset.select(range(min(n_examples, len(subset))))

    # A fixed seed keeps the train/test membership stable across runs.
    splits = subset.train_test_split(test_size=test_size, seed=seed)

    tok_func = partial(tokenize, tokenizer=tokenizer)
    return splits.map(tok_func, num_proc=num_proc)

# Tokenized train/test splits; the Trainer cell reads ds["train"] and ds["test"].
ds = load_data(tokenizer)

Map (num_proc=8):   0%|          | 0/3750 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/1250 [00:00<?, ? examples/s]

In [5]:
# Reuse the end-of-sequence token as the padding token. This is assigned before
# the collator is created so the collator sees a tokenizer with a pad token.
tokenizer.pad_token = tokenizer.eos_token

# Batch collator handed to the Trainer. mlm_probability=0.15 is presumably the
# fraction of tokens selected for masking - confirm against the molT collator.
data_collator = DataCollatorForMaskedMolecularModeling(
    mlm_probability=0.15,
    tokenizer=tokenizer,
)

In [6]:
from sklearn.metrics import r2_score
from transformers.trainer_utils import SchedulerType
from transformers import TrainerCallback

def fn_metrics(eval_pred):
    """Reduce the four evaluation outputs to one scalar metric each.

    Args:
        eval_pred: Object whose ``predictions`` attribute unpacks into exactly
            four array-likes, in the order masked-modelling, atom-property,
            bond-property and molecule-descriptor (presumably the per-batch
            losses returned by the model - confirm against MolTForMaskedMM).

    Returns:
        Dict mapping ``mm_loss``, ``atom_prop_loss``, ``bond_prop_loss`` and
        ``mol_desc_loss`` to the ``.mean()`` of the corresponding array.
    """
    # Strict 4-way unpack: any other number of outputs raises ValueError.
    mm, atom_prop, bond_prop, mol_desc = eval_pred.predictions

    named_losses = (
        ("mm_loss", mm),
        ('atom_prop_loss', atom_prop),
        ('bond_prop_loss', bond_prop),
        ('mol_desc_loss', mol_desc),
    )
    metrics = {}
    for metric_name, values in named_losses:
        metrics[metric_name] = values.mean()
    return metrics

# NOTE(review): 32 sequences x 64 accumulation steps = 2048 sequences per
# optimizer step (per device). The run recorded in this notebook used a
# 3,750-example train split and ended at global_step=4 / epoch 2.17 instead of
# the 4 epochs requested here - confirm that this accumulation factor is intended.
training_args = TrainingArguments(
    # --- output / reporting ---
    output_dir="molT_runs",
    push_to_hub=False,
    report_to='tensorboard',
    # --- optimisation ---
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.2,
    num_train_epochs=4,
    # --- batching ---
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=64,
    # --- logging / evaluation cadence ---
    logging_steps=2,
    evaluation_strategy="steps",
    eval_steps=4,
    # --- reproducibility ---
    data_seed=42,
    # --- options tried earlier, currently disabled ---
    # dataloader_num_workers=4,
    # lr_scheduler_type=SchedulerType.COSINE_WITH_RESTARTS,
    # label_names = ['reg'],
    # load_best_model_at_end = True,
    # metric_for_best_model = "eval_loss",
    # max_grad_norm=0.5
    # greater_is_better = False
)

# Wire together the model, arguments, data splits, collator and metric function
# created in the cells above.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    compute_metrics=fn_metrics,
)

# Kept as the cell's final expression so the returned TrainOutput is displayed.
trainer.train()

  0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 21.8381, 'learning_rate': 0.00013333333333333334, 'epoch': 1.08}
{'loss': 69.4462, 'learning_rate': 0.0, 'epoch': 2.17}


  0%|          | 0/40 [00:00<?, ?it/s]

{'eval_loss': 28.63437843322754, 'eval_mm_loss': 2.345188856124878, 'eval_atom_prop_loss': 2.382370710372925, 'eval_bond_prop_loss': 1.6494123935699463, 'eval_mol_desc_loss': 2179.296875, 'eval_runtime': 9.466, 'eval_samples_per_second': 132.052, 'eval_steps_per_second': 4.226, 'epoch': 2.17}
{'train_runtime': 152.2793, 'train_samples_per_second': 98.503, 'train_steps_per_second': 0.026, 'train_loss': 45.64213275909424, 'epoch': 2.17}


TrainOutput(global_step=4, training_loss=45.64213275909424, metrics={'train_runtime': 152.2793, 'train_samples_per_second': 98.503, 'train_steps_per_second': 0.026, 'train_loss': 45.64213275909424, 'epoch': 2.17})