In [None]:
# Imports

import os

from jointformer.configs.dataset import DatasetConfig
from jointformer.configs.tokenizer import TokenizerConfig
from jointformer.configs.model import ModelConfig
from jointformer.configs.trainer import TrainerConfig

from jointformer.utils.datasets.auto import AutoDataset
from jointformer.utils.tokenizers.auto import AutoTokenizer
from jointformer.models.auto import AutoModel
from jointformer.trainers.trainer import Trainer

%load_ext autoreload
%autoreload 2

In [None]:
# Configs
#
# Machine-specific locations used by the rest of the notebook. The four
# component config paths are derived from REPOSITORY_DIR so that pointing the
# notebook at a different checkout only requires editing one line, and the
# paths can never drift out of sync with the directory passed to os.chdir.

REPOSITORY_DIR = '/home/adamizdebski/projects/jointformer'
DATA_DIR = '/home/adamizdebski/files/data'
OUTPUT_DIR = '/home/adamizdebski/files/jointformer/results/pretrain'

PATH_TO_DATASET_CONFIG = os.path.join(REPOSITORY_DIR, 'configs', 'datasets', 'guacamol', 'physchem')
PATH_TO_TOKENIZER_CONFIG = os.path.join(REPOSITORY_DIR, 'configs', 'tokenizers', 'smiles')
PATH_TO_MODEL_CONFIG = os.path.join(REPOSITORY_DIR, 'configs', 'models', 'jointformer_test')
PATH_TO_TRAINER_CONFIG = os.path.join(REPOSITORY_DIR, 'configs', 'trainers', 'test')

In [None]:
# Switch the kernel's working directory to the repository root.
# NOTE(review): presumably needed so that relative paths used by the project
# resolve against the checkout — confirm.
os.chdir(REPOSITORY_DIR)

In [None]:
# Load the dataset splits and the tokenizer

dataset_config = DatasetConfig.from_config_file(PATH_TO_DATASET_CONFIG)
tokenizer_config = TokenizerConfig.from_config_file(PATH_TO_TOKENIZER_CONFIG)

# Both splits come from the same dataset config and data directory; only the
# split name differs, so build them in one pass ('train' first, then 'val').
train_dataset, val_dataset = (
    AutoDataset.from_config(dataset_config, data_dir=DATA_DIR, split=split_name)
    for split_name in ('train', 'val')
)

tokenizer = AutoTokenizer.from_config(tokenizer_config)

In [None]:
# verify dataset

from rdkit import Chem
from tqdm import tqdm
import torch

def verify_dataset(dataset):
    """Scan a dataset for unparsable SMILES strings and NaN targets.

    Args:
        dataset: Iterable yielding ``(smiles, target)`` pairs. ``target`` may
            be a tensor or anything ``torch.as_tensor`` can convert.

    Returns:
        dict with two lists of integer positions (in iteration order):
            ``'nonvalid_molecule_idx'``: items whose SMILES could not be
                parsed by RDKit.
            ``'nonvalid_target_idx'``: items whose target contains at least
                one NaN.
    """
    nonvalid_molecule_idx = []
    nonvalid_target_idx = []

    for idx, (smiles, target) in enumerate(tqdm(dataset)):
        # RDKit reports a parse failure by returning None rather than raising,
        # so the result itself has to be checked. Exceptions (e.g. a
        # non-string input) are still treated as an invalid molecule, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            molecule = Chem.MolFromSmiles(smiles)
        except Exception:
            molecule = None
        if molecule is None:
            nonvalid_molecule_idx.append(idx)

        # A target is invalid as soon as any of its entries is NaN.
        if torch.isnan(torch.as_tensor(target)).any():
            nonvalid_target_idx.append(idx)

    return {
        'nonvalid_molecule_idx': nonvalid_molecule_idx,
        'nonvalid_target_idx': nonvalid_target_idx
    }


In [None]:
# Read the model configuration from PATH_TO_MODEL_CONFIG and build the model from it.
model_config = ModelConfig.from_config_file(PATH_TO_MODEL_CONFIG)
model = AutoModel.from_config(model_config)

In [None]:
# Read the trainer configuration from PATH_TO_TRAINER_CONFIG, then set its
# batch size to 4 for this notebook run, replacing whatever value was loaded.
trainer_config = TrainerConfig.from_config_file(PATH_TO_TRAINER_CONFIG)
trainer_config.batch_size = 4

In [None]:

# Assemble the trainer from the trainer config, the model, both dataset splits
# and the tokenizer created in the earlier cells.
trainer = Trainer(
    config=trainer_config,
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    tokenizer=tokenizer
    )


In [None]:
# Run training with the trainer assembled in the previous cell.
trainer.train()

In [None]:
# Call the model's generate method with the tokenizer and a batch size of 32;
# as the last expression in the cell, its return value is shown as the output.
# NOTE(review): the device is hard-coded to 'cuda:0' — presumably this requires
# a CUDA-capable GPU; confirm, or make the device configurable.
model.generate(tokenizer=tokenizer, batch_size=32, device='cuda:0')
