In [1]:
from datasets import load_dataset
from sklearn import metrics
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import SchedulerType

from data import generate_and_scale_mol_descriptors
from molT import (
    DataCollatorForMaskedMolecularModeling,
    MolTConfig,
    MolTForMaskedMM,
    MolTTokenizer,
)

In [2]:
# import rdkit.Chem as Chem
# smiles = 'CC'
# from rdkit.Chem.Descriptors import CalcMolDescriptors
# mol = Chem.MolFromSmiles(smiles)
# CalcMolDescriptors(mol)

In [3]:
# Default MolT configuration; the tokenizer and the model below are both
# constructed from this same config object.
model_config = MolTConfig()
tokenizer = MolTTokenizer(model_config)
# The model is built directly from the config (no checkpoint path is passed here).
model = MolTForMaskedMM(model_config)

In [4]:
def tokenize(entry, tokenizer):
    """Tokenize one dataset row with ``tokenizer``.

    The row's ``"smiles"`` value is passed as the positional input; every other
    field of the row is forwarded to the tokenizer as a keyword argument.

    Args:
        entry: Mapping-like row that contains a ``"smiles"`` key.
        tokenizer: Callable invoked as ``tokenizer(smiles, **options)``.

    Returns:
        Whatever ``tokenizer`` returns.

    Raises:
        KeyError: If ``entry`` has no ``"smiles"`` key.
    """
    # Work on a shallow copy so the caller's row is left untouched.
    extra_fields = dict(entry)
    smiles_string = extra_fields.pop("smiles")

    # Options that are always sent, regardless of the row's contents.
    fixed_options = dict(
        truncation=False,
        return_attention_mask=True,
        return_special_tokens_mask=True,
    )
    return tokenizer(smiles_string, **fixed_options, **extra_fields)

# Take the first 300k rows of the ZINC validation split, then carve out a
# seeded train/test partition from that subset.
zinc_subset = load_dataset("sagawa/ZINC-canonicalized")["validation"].select(
    range(300_000)
)
ds = zinc_subset.train_test_split(seed=42)

# Attach the molecular descriptors listed in the model config; the second
# return value of the helper is not needed in this notebook.
ds, _ = generate_and_scale_mol_descriptors(
    ds,
    model_config.mol_descriptors,
    num_samples=100_000,
    num_proc=16,
)

Map (num_proc=8):   0%|          | 0/3750 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/1250 [00:00<?, ? examples/s]

In [None]:
def fn_metrics(eval_pred):
    """Summarise evaluation outputs into scalar metrics for the Trainer.

    ``eval_pred.predictions`` must unpack into exactly eight arrays, in this
    order: ``mm_loss``, ``atom_prop_loss``, ``bond_prop_loss``,
    ``mol_desc_loss``, ``target_loss``, ``target_mask``,
    ``pred_target_values`` and ``true_target_values``.

    A position counts as carrying a target when its ``target_mask`` entry is
    neither ``0`` nor ``-100`` (``-100`` is presumably the collator's ignore
    index -- TODO confirm). Regression metrics are computed over those
    positions only.

    The input arrays are never modified. When no position carries a target,
    the three target metrics are reported as ``nan`` instead of letting the
    metric functions raise on empty input and abort the evaluation loop.

    Args:
        eval_pred: Object exposing a ``predictions`` attribute as described
            above (arrays are assumed to be NumPy arrays).

    Returns:
        dict with the mean of each of the five losses plus ``target_r2``,
        ``target_mae`` and ``target_mse``.
    """
    (
        mm_loss,
        atom_prop_loss,
        bond_prop_loss,
        mol_desc_loss,
        target_loss,
        target_mask,
        pred_target_values,
        true_target_values,
    ) = eval_pred.predictions

    # Derive the boolean selector without writing into the caller's arrays.
    has_target = (target_mask != 0) & (target_mask != -100)

    y_true = true_target_values[has_target]
    y_pred = pred_target_values[has_target]

    if y_true.size == 0:
        # Nothing to score in this evaluation set.
        r2_score = mae = mse = float("nan")
    else:
        r2_score = metrics.r2_score(y_true, y_pred)
        mae = metrics.mean_absolute_error(y_true, y_pred)
        mse = metrics.mean_squared_error(y_true, y_pred)

    return {
        "mm_loss": mm_loss.mean(),
        "atom_prop_loss": atom_prop_loss.mean(),
        "bond_prop_loss": bond_prop_loss.mean(),
        "mol_desc_loss": mol_desc_loss.mean(),
        "target_loss": target_loss.mean(),
        "target_r2": r2_score,
        "target_mae": mae,
        "target_mse": mse,
    }

In [6]:
# Trainer configuration, grouped by concern.
training_args = TrainingArguments(
    # --- run bookkeeping ---
    output_dir="molT_runs",
    run_name="molt_dev_v2",
    report_to="wandb",
    push_to_hub=False,
    # --- optimisation schedule ---
    num_train_epochs=4,
    learning_rate=2e-4,
    lr_scheduler_type=SchedulerType.COSINE,
    warmup_ratio=0.1,
    weight_decay=0.01,
    # --- batching: 128 samples per device, accumulated over 16 batches ---
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    gradient_accumulation_steps=16,
    # --- logging / evaluation cadence ---
    evaluation_strategy="steps",
    eval_steps=32,
    logging_steps=8,
    # --- data loading ---
    dataloader_num_workers=8,
    dataloader_pin_memory=True,
    data_seed=42,
    # --- numeric precision ---
    bf16=True,
    bf16_full_eval=True,
)

# Reuse the end-of-sequence token for padding before the collator is built.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForMaskedMolecularModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    compute_metrics=fn_metrics,
)

# Left as the final expression so the cell displays the returned TrainOutput.
trainer.train()

  0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 21.8381, 'learning_rate': 0.00013333333333333334, 'epoch': 1.08}
{'loss': 69.4462, 'learning_rate': 0.0, 'epoch': 2.17}


  0%|          | 0/40 [00:00<?, ?it/s]

{'eval_loss': 28.63437843322754, 'eval_mm_loss': 2.345188856124878, 'eval_atom_prop_loss': 2.382370710372925, 'eval_bond_prop_loss': 1.6494123935699463, 'eval_mol_desc_loss': 2179.296875, 'eval_runtime': 9.466, 'eval_samples_per_second': 132.052, 'eval_steps_per_second': 4.226, 'epoch': 2.17}
{'train_runtime': 152.2793, 'train_samples_per_second': 98.503, 'train_steps_per_second': 0.026, 'train_loss': 45.64213275909424, 'epoch': 2.17}


TrainOutput(global_step=4, training_loss=45.64213275909424, metrics={'train_runtime': 152.2793, 'train_samples_per_second': 98.503, 'train_steps_per_second': 0.026, 'train_loss': 45.64213275909424, 'epoch': 2.17})