# Debug training

In [2]:
# Imports

import os, logging, argparse, sys

import torch

from hyformer.configs.dataset import DatasetConfig
from hyformer.configs.tokenizer import TokenizerConfig
from hyformer.configs.model import ModelConfig
from hyformer.configs.trainer import TrainerConfig


from hyformer.utils.datasets.auto import AutoDataset
from hyformer.utils.tokenizers.auto import AutoTokenizer
from hyformer.models.auto import AutoModel


from hyformer.trainers.trainer import Trainer

from hyformer.utils.reproducibility import set_seed

# autoreload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Root directory that holds the datasets. The default is the original cluster
# location; set the HYFORMER_DATA_DIR environment variable to point the notebook
# at a different copy of the data without editing this cell.
DATA_DIR = os.environ.get("HYFORMER_DATA_DIR", "/lustre/groups/aih/hyformer/data")

# JSON configuration files for each component. These are relative paths —
# presumably resolved against the kernel's working directory (TODO confirm).
DATASET_CONFIG_PATH = "configs/datasets/guacamol/config.json"
TOKENIZER_CONFIG_PATH = "configs/tokenizers/smiles/guacamol/config.json"
MODEL_CONFIG_PATH = "configs/models/hyformer_small/config.json"
TRAINER_CONFIG_PATH = "configs/trainers/distribution_learning/guacamol/lm/config.json"


In [4]:
# Parse every JSON config file into its matching config object
# (dataset, tokenizer, model, trainer — loaded in that order).
dataset_config, tokenizer_config, model_config, trainer_config = (
    config_cls.from_config_filepath(config_path)
    for config_cls, config_path in (
        (DatasetConfig, DATASET_CONFIG_PATH),
        (TokenizerConfig, TOKENIZER_CONFIG_PATH),
        (ModelConfig, MODEL_CONFIG_PATH),
        (TrainerConfig, TRAINER_CONFIG_PATH),
    )
)


In [5]:
# Build the training and validation splits from the same dataset config and data root.
train_dataset, val_dataset = (
    AutoDataset.from_config(dataset_config, split=split_name, root=DATA_DIR)
    for split_name in ('train', 'val')
)


In [7]:
# Build the tokenizer described by the tokenizer config loaded above.
tokenizer = AutoTokenizer.from_config(tokenizer_config)


In [None]:

# Build the model from its config. The Trainer cell below passes `model=model`,
# so this assignment must actually run: with it commented out, a fresh
# Restart & Run All fails with a NameError on `model`.
model = AutoModel.from_config(model_config)
   

In [8]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [15]:
# Wire the trainer config, model, tokenizer and target device into a Trainer.
trainer = Trainer(config=trainer_config, model=model, tokenizer=tokenizer, device=device)



In [9]:
# Take the 'data' field of the first two training examples as raw inputs for the tokenizer checks.
samples = [train_dataset[idx]['data'] for idx in (0, 1)]

In [10]:
# Display the two raw strings taken from the training split.
samples


['CCC(C)(C)Br', 'CCCN(CCc1cccc(-c2ccccc2)c1)C(=O)C1OC(C(=O)O)=CC(N)C1NC(C)=O']

In [15]:
# Tokenize both samples for the 'lm' task; the displayed result holds
# `input_ids` and `attention_mask`, one list per sample.
tokenizer(samples, task='lm')

{'input_ids': [[102, 19, 19, 19, 4, 19, 5, 4, 19, 5, 18, 103],
  [102,
   19,
   19,
   19,
   23,
   4,
   19,
   19,
   97,
   7,
   97,
   97,
   97,
   97,
   4,
   6,
   97,
   8,
   97,
   97,
   97,
   97,
   97,
   8,
   5,
   97,
   7,
   5,
   19,
   4,
   16,
   24,
   5,
   19,
   7,
   24,
   19,
   4,
   19,
   4,
   16,
   24,
   5,
   24,
   5,
   16,
   19,
   19,
   4,
   23,
   5,
   19,
   7,
   23,
   19,
   4,
   19,
   5,
   16,
   24,
   103]],
 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]]}

In [26]:
# Round-trip check: encode the samples, then decode the first sample's ids
# back to text and compare it by eye with samples[0].
tokenizer.decode(torch.tensor(tokenizer(samples, task='lm')['input_ids'][0]))

'CCC(C)(C)Br'

In [16]:
# Shuffled loader over the training split, built by the trainer for the tasks
# listed in its config.
trainer_loader = trainer.create_loader(train_dataset, shuffle=True, tasks=trainer.config.tasks)



In [17]:
# Pull a single batch from the loader for manual inspection.
batch = next(iter(trainer_loader))



In [24]:
# Decode example 5 of the batch back to a string.
tokenizer.decode(batch['input_ids'][5])

'O=C(O)C1C2CC=CC2c2cc(Cl)cc3c2N1CC1CC=CC31'

In [33]:
# Raw token ids of example 5.
# NOTE(review): the trailing run of one repeated id in the output is presumably padding — confirm.
batch['input_ids'][5]

tensor([508, 503,  29,  21,  24,   6,  29,   7,  24,  12,  24,  13,  24,  24,
         21,  24,  24,  13, 498,  13, 498, 498,   6,  25,   7, 498, 498,  14,
        498,  13,  28,  12,  24,  24,  12,  24,  24,  21,  24,  24,  14,  12,
        504, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505,
        505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505,
        505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505,
        505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505,
        505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505,
        505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505, 505,
        505, 505])

In [34]:
# Label tensor of example 5.
# NOTE(review): the -100 entries are presumably positions excluded from the loss — confirm.
batch['input_labels'][5]

tensor([-100, -100,   29,   21,   24,    6,   29,    7,   24,   12,   24,   13,
          24,   24,   21,   24,   24,   13,  498,   13,  498,  498,    6,   25,
           7,  498,  498,   14,  498,   13,   28,   12,   24,   24,   12,   24,
          24,   21,   24,   24,   14,   12,  504, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100])

In [29]:
# Attention mask of example 5 (a run of 1s followed by 0s in the output below).
batch['attention_mask'][5]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [1]:
import torch 

from hyformer.models.layers.rotary import RotaryEmbedding


In [16]:
import torch
from hyformer.models.layers.rotary import RotaryEmbedding

# Create the rotary positional embedding; in a real transformer one instance
# would be handed to every attention layer.
# NOTE(review): dim=32 is smaller than the head dimension (64) used below —
# presumably only part of each head is rotated; confirm this is intended.
rotary_emb = RotaryEmbedding(dim=32)

# Mock queries and keys laid out as (batch, heads, seq len, dimension of head).
# The last two axes must be (seq_len, feature dimension); any number of leading axes is fine.
qk_shape = (1, 2, 8, 64)
q = torch.randn(qk_shape)  # queries
k = torch.randn(qk_shape)  # keys

# Rotate queries and keys once the heads are split out, but before the
# dot product and softmax of attention.
q, k = (rotary_emb.rotate_queries_or_keys(t) for t in (q, k))

# Attention then proceeds with the rotated q and k as usual.

In [22]:
# Position whose query is kept for the cached-key experiment.
seq_idx = 2

# Fresh mock tensors, (batch, heads, seq len, dimension of head); queries are drawn first, then keys.
q, k = torch.randn(1, 2, 8, 64), torch.randn(1, 2, 8, 64)

# Keep only the query at seq_idx (a copy, with the sequence axis kept at length 1).
q = q.index_select(2, torch.tensor([seq_idx]))

# Blank out, in place, every key after seq_idx; the key tensor keeps its full length.
k[..., seq_idx + 1:, :] = 0




In [None]:
# Rotate the single-position query together with the full key tensor.
# NOTE(review): q has sequence length 1 while k still has all 8 positions (keys after
# seq_idx were zeroed, not dropped) — confirm the query receives the intended position.
q_rotated, k_rotated = rotary_emb.rotate_queries_with_cached_keys(q, k) # pass all cached keys

In [43]:
# Apply the same rotation a second time to the already-rotated tensors, so the
# result can be compared against the single rotation in the cells below.
q_double_rotated, k_double_rotated = rotary_emb.rotate_queries_with_cached_keys(q_rotated, k_rotated) # pass all cached keys


In [48]:

# Compare the keys at sequence position 0 after one and after two rotations.
print(torch.allclose(k_double_rotated.select(2, 0), k_rotated.select(2, 0)))

True


In [51]:
# Keys at sequence position 1 after a single rotation.
k_rotated[:, :, 1, :]


tensor([[[ 1.8349,  0.1951, -1.3999, -0.0461,  1.8644,  0.8243,  0.5217,
           0.0165, -0.1531, -1.5532,  0.0227, -0.6695, -0.4033, -1.7691,
           0.4640,  1.5651,  0.0600,  1.5951,  0.9842,  0.4906,  0.9918,
          -0.5005, -3.1655,  0.0134,  0.8975, -1.4514,  0.1484, -0.2623,
           0.6115,  2.1162,  0.5849, -0.7527,  0.1158, -0.4192,  0.6878,
           0.4351, -0.2205, -1.1369,  0.7604,  1.0182,  0.6992, -0.6007,
          -1.6325,  0.0287,  0.5296, -1.2345, -0.7244, -0.5330, -1.7051,
           1.2483,  0.6185,  0.1223,  1.1791, -0.2623, -0.7074, -0.8204,
          -0.5255, -0.3207, -0.9353, -0.3042,  0.5654,  0.2660,  0.8855,
          -0.8075],
         [-0.0882,  0.3458,  0.3276, -0.4652, -1.5064, -0.8110, -1.6426,
          -0.4660,  0.8792, -2.8825, -1.5115, -0.6945,  1.9446, -1.5879,
          -0.0137, -0.6306, -0.5265,  0.1023,  1.1742, -0.1347,  0.8002,
          -0.5413,  0.4216,  1.1626, -0.4917,  0.8571,  1.1744, -0.5008,
          -0.1364, -0.2443,  3.

In [52]:
# Keys at sequence position 1 after two rotations; compare with the single-rotation values above.
k_double_rotated[:, :, 1, :]

tensor([[[ 8.2724e-01,  1.6494e+00, -1.1598e+00, -7.8542e-01,  1.5156e+00,
           1.3633e+00,  5.1060e-01,  1.0855e-01,  2.7415e-03, -1.5607e+00,
           6.0266e-02, -6.6716e-01, -3.4714e-01, -1.7810e+00,  4.3605e-01,
           1.5731e+00,  4.4025e-02,  1.5956e+00,  9.8138e-01,  4.9612e-01,
           9.9339e-01, -4.9738e-01, -3.1655e+00,  7.7229e-03,  8.9898e-01,
          -1.4505e+00,  1.4860e-01, -2.6223e-01,  6.1085e-01,  2.1164e+00,
           5.8504e-01, -7.5256e-01,  1.1581e-01, -4.1919e-01,  6.8779e-01,
           4.3509e-01, -2.2048e-01, -1.1369e+00,  7.6039e-01,  1.0182e+00,
           6.9916e-01, -6.0072e-01, -1.6325e+00,  2.8681e-02,  5.2958e-01,
          -1.2345e+00, -7.2438e-01, -5.3301e-01, -1.7051e+00,  1.2483e+00,
           6.1851e-01,  1.2232e-01,  1.1791e+00, -2.6227e-01, -7.0736e-01,
          -8.2045e-01, -5.2549e-01, -3.2071e-01, -9.3528e-01, -3.0422e-01,
           5.6540e-01,  2.6597e-01,  8.8553e-01, -8.0746e-01],
         [-3.3865e-01,  1.1264e-01,  