In [1]:
# A recipe for fine-tuning a pre-trained Jointformer model on the QED dataset.

In [24]:
# Imports

import os

from jointformer.configs.dataset import DatasetConfig
from jointformer.configs.tokenizer import TokenizerConfig
from jointformer.configs.model import ModelConfig
from jointformer.configs.trainer import TrainerConfig

from jointformer.utils.datasets.auto import AutoDataset
from jointformer.utils.tokenizers.auto import AutoTokenizer
from jointformer.models.auto import AutoModel
from jointformer.trainers.trainer import Trainer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
# Configs
#
# NOTE(review): REPOSITORY_DIR, DATA_DIR and OUTPUT_DIR are machine-specific
# absolute paths; edit them for your environment before running.

TARGET_LABEL = 'qed'
REPOSITORY_DIR = '/home/adamizdebski/projects/jointformer'
DATA_DIR = '/home/adamizdebski/files/data'
OUTPUT_DIR = f'/home/adamizdebski/files/jointformer/results/jointformer/finetune/{TARGET_LABEL}'

# Every config path is derived from REPOSITORY_DIR so that pointing the notebook
# at a different checkout only requires changing that one constant.
PATH_TO_DATASET_CONFIG   = f'{REPOSITORY_DIR}/configs/datasets/guacamol/{TARGET_LABEL}'
PATH_TO_TOKENIZER_CONFIG = f'{REPOSITORY_DIR}/configs/tokenizers/smiles'
PATH_TO_MODEL_CONFIG     = f'{REPOSITORY_DIR}/configs/models/jointformer_test'
PATH_TO_TRAINER_CONFIG   = f'{REPOSITORY_DIR}/configs/trainers/test'

In [26]:
# Make the repository root the process working directory.
# NOTE(review): presumably so that relative paths inside jointformer resolve
# against the checkout — confirm. This mutates global process state.
os.chdir(REPOSITORY_DIR)

In [27]:
# Load the dataset splits and the tokenizer

dataset_config = DatasetConfig.from_config_file(PATH_TO_DATASET_CONFIG)
tokenizer_config = TokenizerConfig.from_config_file(PATH_TO_TOKENIZER_CONFIG)

# One dataset per split, built from the same config in train -> val -> test order.
train_dataset, val_dataset, test_dataset = [
    AutoDataset.from_config(dataset_config, data_dir=DATA_DIR, split=split_name)
    for split_name in ('train', 'val', 'test')
]

tokenizer = AutoTokenizer.from_config(tokenizer_config)

In [18]:
# verify dataset

from rdkit import Chem
from tqdm import tqdm
import torch

def verify_dataset(dataset):
    """Scan a dataset for unparsable SMILES strings and NaN targets.

    Args:
        dataset: Iterable yielding ``(smiles, target)`` pairs, where ``target``
            is a ``torch.Tensor``.

    Returns:
        dict with two lists of integer positions (in iteration order):
            ``'nonvalid_molecule_idx'`` - items whose SMILES could not be
                parsed by RDKit.
            ``'nonvalid_target_idx'`` - items whose target contains at least
                one NaN.
    """
    nonvalid_molecule_idx = []
    nonvalid_target_idx = []

    for idx, (smiles, target) in enumerate(tqdm(dataset)):
        # RDKit reports an unparsable SMILES by returning None, not by raising,
        # so the result itself must be checked. The try/except only guards
        # against inputs RDKit rejects outright (e.g. a non-string value).
        try:
            molecule = Chem.MolFromSmiles(smiles)
        except Exception:
            molecule = None
        if molecule is None:
            nonvalid_molecule_idx.append(idx)

        # Flag the item if any element of the target is NaN.
        if torch.isnan(target).any():
            nonvalid_target_idx.append(idx)

    return {
        'nonvalid_molecule_idx': nonvalid_molecule_idx,
        'nonvalid_target_idx': nonvalid_target_idx
    }


In [30]:
# Read the model configuration from PATH_TO_MODEL_CONFIG and instantiate the model from it.
# NOTE(review): no checkpoint is loaded in this cell although the notebook is described
# as fine-tuning a pre-trained model — confirm whether AutoModel.from_config restores weights.
model_config = ModelConfig.from_config_file(PATH_TO_MODEL_CONFIG)
model = AutoModel.from_config(model_config)

In [31]:
# Read the trainer configuration from PATH_TO_TRAINER_CONFIG.
trainer_config = TrainerConfig.from_config_file(PATH_TO_TRAINER_CONFIG)

# Wire the model, the three dataset splits and the tokenizer into a Trainer.
# NOTE(review): OUTPUT_DIR is defined in the config cell but is not passed here or
# used anywhere else in the visible cells — confirm where checkpoints are written.
trainer = Trainer(
    config=trainer_config,
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    tokenizer=tokenizer
    )


INFO: Random seed set to 1337
INFO: tokens per iteration set to: 256


In [32]:
# Run training with the trainer built above; progress is reported through the INFO log lines.
trainer.train()

INFO: Evaluation at step 0: train loss 6.3925, val loss 6.4002
INFO: iter 100: loss 5.615761 on prediction task, lr 0.000600, time 259.22ms, mfu 0.00%
INFO: Evaluation at step 200: train loss 4.6265, val loss 4.6266
INFO: Validation loss: 4.6266
INFO: Best validation loss: 1000000000.0000
INFO: Checkpoint updated at iteration 200
INFO: iter 200: loss 4.591779 on prediction task, lr 0.000300, time 8209.45ms, mfu 0.00%
INFO: iter 300: loss 0.018068 on generation task, lr 0.000001, time 237.85ms, mfu 0.00%
INFO: Evaluation at step 400: train loss 4.3897, val loss 4.4003
INFO: Validation loss: 4.4003
INFO: Best validation loss: 4.6266
INFO: Checkpoint updated at iteration 400
INFO: iter 400: loss 0.202759 on generation task, lr 0.000001, time 7853.10ms, mfu 0.00%
INFO: iter 500: loss 4.326705 on prediction task, lr 0.000001, time 244.95ms, mfu 0.00%
INFO: Evaluation at step 600: train loss 4.3938, val loss 4.3895
INFO: Validation loss: 4.3895
INFO: Best validation loss: 4.4003
INFO: Checkp

In [33]:
# Evaluate the trained model; the returned value is displayed as the cell output.
# NOTE(review): which metric this scalar represents is not visible here — confirm.
trainer.test()

0.22091832756996155