# Jointformer Training

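This notebook shows how to set up and train Jointformer. It builds the dataset, tokenizer, model and trainer from their config files, then sanity-checks the model with a forward pass on a single training batch.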

In [6]:
# Imports

import os

from jointformer.configs.dataset import DatasetConfig
from jointformer.configs.tokenizer import TokenizerConfig
from jointformer.configs.model import ModelConfig
from jointformer.configs.trainer import TrainerConfig

from jointformer.utils.datasets.auto import AutoDataset
from jointformer.utils.tokenizers.auto import AutoTokenizer
from jointformer.models.auto import AutoModel
from jointformer.trainers.trainer import Trainer

# Re-import edited modules (e.g. the jointformer package) automatically before each cell runs,
# so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Set working directory of the project
#
# The repository location can be overridden with the JOINTFORMER_REPO_DIR environment
# variable; the default is the original author's checkout. Changing into it presumably lets
# repository-relative paths resolve — TODO confirm which code relies on the cwd.

REPOSITORY_DIR = os.environ.get('JOINTFORMER_REPO_DIR', '/home/adamizdebski/projects/jointformer')
os.chdir(REPOSITORY_DIR)

In [8]:
# Configs
#
# DATA_DIR / OUTPUT_DIR default to the original author's machine-specific locations and can be
# overridden with the JOINTFORMER_DATA_DIR / JOINTFORMER_OUTPUT_DIR environment variables.
# The component config paths live inside the repository, so they are derived from
# REPOSITORY_DIR (defined in the working-directory cell) rather than repeating its prefix.

DATA_DIR = os.environ.get('JOINTFORMER_DATA_DIR', '/home/adamizdebski/files/data')
OUTPUT_DIR = os.environ.get('JOINTFORMER_OUTPUT_DIR', '/home/adamizdebski/files/jointformer/results/finetune')

CONFIGS_DIR = os.path.join(REPOSITORY_DIR, 'configs')

PATH_TO_DATASET_CONFIG   = os.path.join(CONFIGS_DIR, 'datasets', 'guacamol', 'physchem')
PATH_TO_TOKENIZER_CONFIG = os.path.join(CONFIGS_DIR, 'tokenizers', 'smiles_separate_task_token')
PATH_TO_MODEL_CONFIG     = os.path.join(CONFIGS_DIR, 'models', 'jointformer_v2')
PATH_TO_TRAINER_CONFIG   = os.path.join(CONFIGS_DIR, 'trainers', 'test')

In [9]:
# Load the dataset and tokenizer configs, then build the train/val datasets and the tokenizer.

dataset_config = DatasetConfig.from_config_file(PATH_TO_DATASET_CONFIG)
tokenizer_config = TokenizerConfig.from_config_file(PATH_TO_TOKENIZER_CONFIG)

# One dataset per split, built in order: 'train' first, then 'val'.
train_dataset, val_dataset = (
    AutoDataset.from_config(dataset_config, data_dir=DATA_DIR, split=split_name)
    for split_name in ('train', 'val')
)

tokenizer = AutoTokenizer.from_config(tokenizer_config)

In [13]:
# Init Jointformer
#
# Build the model from its config file, passing the dataset's downstream task type and number
# of tasks (both read from the dataset config) plus a hidden dimension.

HIDDEN_DIM = 256  # forwarded to AutoModel.from_config as `hidden_dim`
PRETRAINED_CHECKPOINT = None  # optional checkpoint path, e.g. 'ckpt.pt'; None skips loading

model_config = ModelConfig.from_config_file(PATH_TO_MODEL_CONFIG)
model = AutoModel.from_config(
    model_config,
    downstream_task=dataset_config.task_type,
    num_tasks=dataset_config.num_tasks,
    hidden_dim=HIDDEN_DIM,
)
if PRETRAINED_CHECKPOINT is not None:
    model.load_pretrained(PRETRAINED_CHECKPOINT)

number of parameters: 6.72M


In [14]:
# Inspect the model hyperparameters loaded from PATH_TO_MODEL_CONFIG.
model_config

ModelConfig({'model_name': 'JointformerV2', 'embedding_dim': 256, 'embedding_hidden_dim': 1024, 'num_heads': 8, 'num_local_heads': 8, 'head_dim': 32, 'num_layers': 8, 'bias': False, 'attention_dropout': None, 'feed_forward_dropout': None, 'prediction_dropout': None, 'layer_norm_eps': 1e-05, 'vocab_size': 596, 'max_seq_len': 128, 'prediction_task_type': None, 'num_prediction_tasks': None, 'num_physchem_tasks': 200, 'pretrained_filepath': None, 'predictor_hidden_size': None, 'predictor_dropout': None, 'predictor_num_heads': None, 'prediction_hidden_dim': 256, 'set_separate_task_tokens': None, 'flash_attention': True, 'dropout': 0.1, 'lambda_hparam': None, 'block_size': 128, 'n_embd': 256, 'n_layer': 8, 'n_head': 8, 'num_props': 200})

In [15]:

# Build the trainer from its config file.
# NOTE(review): only train and val splits are loaded in the cells above, so the validation
# split is also passed in as the test set.
trainer_config = TrainerConfig.from_config_file(PATH_TO_TRAINER_CONFIG)
splits = {
    'train_dataset': train_dataset,
    'val_dataset': val_dataset,
    'test_dataset': val_dataset,
}
trainer = Trainer(config=trainer_config, model=model, tokenizer=tokenizer, **splits)

using fused AdamW: True


In [27]:
# Smoke-test one forward pass on a single training batch.
# NOTE(review): `_init_data_loaders` is a private Trainer method; it is called here only to
# populate `trainer.train_loader` without starting a training run.
trainer._init_data_loaders()

batch = next(iter(trainer.train_loader))
# Keep the return value of `.to`, matching the `inputs = inputs.to('cuda')` pattern used for
# tokenizer outputs in this notebook: if `.to` returns a moved copy (as torch.Tensor.to does),
# a bare call would leave the batch on the CPU. Assumes `.to` returns the batch — TODO confirm.
batch = batch.to('cuda')
model.to('cuda')  # presumably an nn.Module, whose .to moves parameters in place
batch['task'] = 'physchem'
outputs = model(**batch)

In [24]:
# Tokenize the first two training examples for the 'generation' task and compute the loss.
first_examples = [train_dataset[i] for i in (0, 1)]
inputs = tokenizer(first_examples, task='generation').to('cuda')
model_outputs = model.get_loss(**inputs)

torch.Size([2, 127])
torch.Size([2, 127, 596])


In [19]:
# Display the outputs of the single-batch forward pass.
# NOTE(review): this cell ran as In [19], before the forward-pass cell (In [27]), so the saved
# output below reflects earlier kernel state; re-run the notebook top to bottom.
outputs

{'token_embeddings': tensor([[[-4.3458e-02, -6.9468e-04,  4.2362e-02,  ...,  2.1533e-02,
           -6.0646e-02, -3.4861e-02],
          [ 6.0578e-03, -4.7434e-02,  1.6648e-02,  ..., -7.4538e-03,
            3.6801e-02,  1.7726e-02],
          [-1.5884e-02, -8.6445e-03, -5.5961e-02,  ...,  1.6495e-02,
            1.9452e-02, -1.6902e-02],
          ...,
          [ 2.8788e-02,  1.5539e-05,  5.0084e-02,  ..., -1.2208e-02,
           -5.4630e-03,  1.4062e-02],
          [ 1.9007e-03,  8.4405e-03,  1.3463e-02,  ..., -5.4869e-03,
            2.1876e-02,  8.2266e-03],
          [ 3.1826e-02,  2.3824e-03, -1.7455e-02,  ..., -9.0177e-03,
            8.7530e-03,  4.6899e-02]],
 
         [[-4.3458e-02, -6.9468e-04,  4.2362e-02,  ...,  2.1533e-02,
           -6.0646e-02, -3.4861e-02],
          [ 6.0578e-03, -4.7434e-02,  1.6648e-02,  ..., -7.4538e-03,
            3.6801e-02,  1.7726e-02],
          [ 3.9049e-02,  3.8860e-02, -1.8694e-02,  ..., -3.8230e-02,
            2.5005e-02, -1.6724e-02],