#### TODOs:
    - extract GuacaMol data without pickling
    - run training and debug successfully
    - get all tokens
    - get a proper env - env.yml (with minimal requirements)
    - reproduce the GuacaMol benchmark properly
    - actually, I need to redo the splits for Fibrosis
    - download UniMol pre-train data or other pre-train data (GuacaMol) and save it using: smiles_array = np.array(smiles_list, dtype='<U50'), to disable pickle. 
    - save data with numpy_version=np.__version__

In [39]:
# Imports

import numpy as np

from hyformer.configs.dataset import DatasetConfig
from hyformer.configs.tokenizer import TokenizerConfig
from hyformer.configs.model import ModelConfig
from hyformer.configs.trainer import TrainerConfig

from hyformer.utils.datasets.sequence import SequenceDataset
from hyformer.utils.datasets.auto import AutoDataset
from hyformer.utils.tokenizers.auto import AutoTokenizer
from hyformer.models.auto import AutoModel
from hyformer.trainers.trainer import Trainer

# auxiliary 
import os
import torch
import random

# reload magic
%reload_ext autoreload
%autoreload 2


In [40]:
import logging
import sys
from IPython.display import HTML, display

# Create a custom handler for notebook display
class NotebookHandler(logging.Handler):
    """Logging handler that renders every record as a colored ``<pre>`` block.

    Each formatted record is also appended to ``self.logs`` so the plain-text
    history can be retrieved later through :meth:`get_logs`.

    Args:
        level: Minimum level handled by this handler (default ``logging.INFO``).
    """

    # Inline CSS applied to the <pre> element, keyed by the record's level name.
    LEVEL_CSS = {
        'DEBUG': 'color: #6c757d;',  # gray
        'INFO': 'color: #0d6efd;',   # blue
        'WARNING': 'color: #ffc107; font-weight: bold;',  # yellow
        'ERROR': 'color: #dc3545; font-weight: bold;',    # red
        'CRITICAL': 'color: #fff; background-color: #dc3545; font-weight: bold; padding: 2px 5px;'  # white on red
    }

    def __init__(self, level=logging.INFO):
        super().__init__(level)
        self.logs = []

    def emit(self, record):
        """Format ``record``, store the text and display it as styled HTML."""
        import html  # local import so the cell does not depend on another cell's imports

        try:
            log_entry = self.format(record)
            self.logs.append(log_entry)

            # Unknown level names fall back to an unstyled <pre>.
            css = self.LEVEL_CSS.get(record.levelname, '')
            # Escape the text: a message containing '<' or '&' (e.g. '<pad>')
            # would otherwise be interpreted as markup and rendered incorrectly.
            safe_entry = html.escape(log_entry, quote=False)
            display(HTML(f'<pre style="{css}">{safe_entry}</pre>'))
        except Exception:
            # Standard logging convention: a failing handler must not break the caller.
            self.handleError(record)

    def get_logs(self):
        """Return the list of formatted (unescaped) log entries emitted so far."""
        return self.logs

# Configure root logger and Trainer's logger
def setup_notebook_logging(level=logging.INFO):
    """Route root and ``hyformer.trainers`` logging to a single NotebookHandler.

    All handlers currently attached to the root logger are removed and replaced
    by one freshly created ``NotebookHandler``. The same handler is attached to
    the ``hyformer.trainers`` logger, whose propagation to the root logger is
    switched off so that each trainer record is displayed exactly once.

    The function is safe to call repeatedly (e.g. when the cell is re-run):
    notebook handlers left on the trainer logger by earlier calls are removed.

    Args:
        level: Level applied to the handler and to both loggers.

    Returns:
        The ``NotebookHandler`` instance that now receives the records.
    """
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    notebook_handler = NotebookHandler(level=level)
    notebook_handler.setFormatter(formatter)

    # Root logger: drop whatever handlers were installed before.
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.addHandler(notebook_handler)
    root_logger.setLevel(level)

    trainer_logger = logging.getLogger('hyformer.trainers')
    # Remove notebook handlers from previous calls. Compare by class name, not
    # isinstance: re-running the defining cell creates a new NotebookHandler
    # class, so older instances would no longer pass an isinstance check.
    for handler in trainer_logger.handlers[:]:
        if type(handler).__name__ == NotebookHandler.__name__:
            trainer_logger.removeHandler(handler)
    trainer_logger.addHandler(notebook_handler)
    # The handler sits on both loggers; without this every trainer record would
    # be handled here and again after propagating to the root logger.
    trainer_logger.propagate = False
    trainer_logger.setLevel(level)

    return notebook_handler

# Install the notebook logging handler. `notebook_logs` holds the handler
# returned by setup_notebook_logging (not a list of messages); call
# notebook_logs.get_logs() to read back the formatted entries.
notebook_logs = setup_notebook_logging(logging.INFO)

# Smoke-test the configuration with one INFO record on the root logger.
logging.info("Logging is set up and working!")

In [42]:
# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Work from the repository root so that the relative config and output paths
# used by later cells resolve. Set the HYFORMER_REPOSITORY_DIR environment
# variable to run on another machine; the default is the original checkout.

REPOSITORY_DIR = os.environ.get('HYFORMER_REPOSITORY_DIR', '/home/aih/adam.izdebski/projects/hyformer')
os.chdir(REPOSITORY_DIR)


In [43]:
# Constants

# OUT_DIR is relative to the working directory; it is created below and
# passed to the Trainer as `out_dir`.
OUT_DIR = "results/notebooks/test"
# DATA_DIR is only referenced by the commented-out AutoDataset loading.
DATA_DIR = "/lustre/groups/aih/hyformer/data"

# Config file paths, relative to the working directory.
# DATASET_CONFIG_FILEPATH is only referenced by commented-out code in this notebook.
DATASET_CONFIG_FILEPATH = "configs/datasets/fibrosis/standard_smiles/random_transfer/config.json"
TOKENIZER_CONFIG_FILEPATH = "configs/tokenizers/smiles/config.json"
MODEL_CONFIG_FILEPATH = "configs/models/hyformer_tiny/config.json"
TRAINER_CONFIG_FILEPATH = "configs/trainers/pretrain_lm_mlm_physchem_debug/config.json"

In [44]:
# Exemplary SMILES strings with random binary labels, used as a tiny test dataset.

smiles_list = [
    "Cc1nc[nH]c1C1CC(=O)Nc2c1c(C)nn2C1CCCCC1",
    "CCCCNC(=O)NS(=O)(=O)c1ccc(C)cc1",
    "COc1ccc(Nc2c3c4c(c(Br)ccc4n(C)c2=O)C(=O)c2ccccc2-3)cc1",
    "Cn1c(=O)c2c(ncn2CC2OCCO2)n(C)c1=O",
    "NC12CC3CC(C1)CC(n1cncn1)(C3)C2",
    "NC(=O)NN=Cc1ccc([N+](=O)[O-])o1",
    "CC1=C(C(=O)OC(C)C)C(c2ccccn2)C2=C(CC(c3ccccc3)CC2=O)N1",
    "CCc1c2c(nc3ccc(OC(=O)N4CCC(N5CCCCC5)CC4)cc13)-c1cc3c(c(=O)n1C2)COC(=O)C3(O)CC",
    "Cc1nc(CN(C)C(=O)C2CC=CCC2)no1",
    "COc1ccnc(C[S+]([O-])c2nc3ccc(-n4cccc4)cc3[nH]2)c1C",
    "CCCCC(CC)COC(=O)C=Cc1ccc(OC)cc1",
    "COCCOC(=O)C1=C(C)NC(C)=C(C(=O)OC(C)C)C1c1cccc([N+](=O)[O-])c1",
    "Cc1ccc2c(c1)N(CCO)C(=Cc1ccc[n+](C)c1)S2",
    "O=C(c1ccccn1)C1CCCN(C2CCC3(CCNCC3)CC2)C1",
    "CCC(=O)OC(OP(=O)(CCCCc1ccccc1)CC(=O)N1CC(C2CCCCC2)CC1C(=O)O)C(C)C",
    "CS(=O)(=O)Nc1ccc([N+](=O)[O-])cc1Oc1ccccc1"
]

# Draw the labels from a dedicated, seeded generator so that they are the same
# on every "Restart & Run All" and do not depend on the global `random` state.
LABEL_SEED = 0
label_rng = random.Random(LABEL_SEED)
target_list = [label_rng.randint(0, 1) for _ in smiles_list]


In [49]:
# Dataset

# dataset_config = DatasetConfig.from_config_file(DATASET_CONFIG_FILEPATH)
# dataset = AutoDataset.from_config(dataset_config, split="test", root=DATA_DIR)

# Config-driven loading is disabled above. The train/val/test splits are three
# separately constructed datasets over the same toy SMILES list, each with its
# own int32 copy of the labels.
train_dataset, val_dataset, test_dataset = (
    SequenceDataset(data=smiles_list, target=np.array(target_list, dtype=np.int32))
    for _ in range(3)
)

In [50]:
# Tokenizer
# Read the tokenizer config from TOKENIZER_CONFIG_FILEPATH and build the
# tokenizer from it through the AutoTokenizer factory.

tokenizer_config = TokenizerConfig.from_config_file(TOKENIZER_CONFIG_FILEPATH)
tokenizer = AutoTokenizer.from_config(tokenizer_config)



In [51]:
# Vocabulary size, taken as len(tokenizer).
# The recorded value (511) matches the embedding / head sizes in the model repr below.

vocab_size = len(tokenizer)
print(f"Vocab size: {vocab_size}")


Vocab size: 511


In [52]:
# Model
# Build the model from MODEL_CONFIG_FILEPATH and move it to `device`.
# `model.to(device)` is the cell's last expression, so the module repr is displayed.

model_config = ModelConfig.from_config_file(MODEL_CONFIG_FILEPATH)
model = AutoModel.from_config(model_config)
model.to(device)




Hyformer(
  (token_embedding): Embedding(511, 16)
  (layers): ModuleList(
    (0): TransformerLayer(
      (attention_layer): Attention(
        (q_proj): Linear(in_features=16, out_features=16, bias=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=False)
        (v_proj): Linear(in_features=16, out_features=16, bias=False)
        (out): Linear(in_features=16, out_features=16, bias=False)
        (relative_embedding): RotaryEmbedding()
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=16, out_features=256, bias=False)
        (w3): Linear(in_features=16, out_features=256, bias=False)
        (w2): Linear(in_features=256, out_features=16, bias=False)
      )
      (attention_layer_normalization): RMSNorm()
      (feed_forward_normalization): RMSNorm()
    )
  )
  (layer_norm): RMSNorm()
  (lm_head): Linear(in_features=16, out_features=511, bias=False)
  (mlm_head): Linear(in_features=16, out_features=511, bias=False)
)

In [53]:
# Create the output directory (and any missing parents). exist_ok=True makes
# the cell idempotent and avoids the check-then-create race of
# `if not os.path.exists(...)`.
os.makedirs(OUT_DIR, exist_ok=True)

In [54]:
# Trainer 
# Build the Trainer from TRAINER_CONFIG_FILEPATH around the model and tokenizer
# created above; OUT_DIR is passed as the trainer's `out_dir`.

trainer_config = TrainerConfig.from_config_file(TRAINER_CONFIG_FILEPATH)
trainer = Trainer.from_config(config=trainer_config, model=model, tokenizer=tokenizer, device=device, out_dir=OUT_DIR)


In [55]:
# Train on the toy train split, passing the toy validation split as `val_dataset`.
# NOTE(review): the output recorded for this cell is a truncated multiprocessing
# QueueFeederThread traceback, so it does not show whether training completed.
trainer.train(train_dataset=train_dataset, val_dataset=val_dataset)

Exception in thread Exception in thread QueueFeederThread:
QueueFeederThread:
Traceback (most recent call last):
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
Traceback (most recent call last):
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/connection.py", line 177, in close
    reader_close()
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/connection.py", line 177, in close
    reader_close()
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/lustre/groups/aih/hyformer/env/lib/python3.10/multiprocessing/connecti



In [33]:
# Test
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'Trainer' object has no attribute 'test'", so this call
# does not match the Trainer API the notebook was run against.

test_metric = trainer.test(test_dataset, 'perplexity')


AttributeError: 'Trainer' object has no attribute 'test'

In [34]:
# Build a shuffled loader over the training set through the trainer.
# presumably `tasks` maps task names to sampling weights ('lm' only here) — TODO confirm
train_loader = trainer.create_loader(train_dataset, shuffle=True, tasks={'lm': 1.0})

In [35]:
# Fetch a single batch and run one forward pass for inspection.
# next(iter(...)) raises StopIteration on an empty loader, instead of silently
# leaving `model_input` / `model_output` undefined or stale from an earlier run.
model_input = next(iter(train_loader))

# Move batch to device - the batch is a dict; only tensor values are moved
model_input = {
    k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
    for k, v in model_input.items()
}

# The output is only inspected (and detached) afterwards, so skip autograd bookkeeping.
with torch.no_grad():
    model_output = model(**model_input)

In [36]:
# Logits of the inspected batch, with masked-out positions set to -inf.
logits = model_output['logits'].detach().cpu()
# Broadcast the attention mask over the last (vocabulary) axis of the logits.
# .bool() is required for correctness when the mask is stored as integers:
# `~` on an integer tensor is a bitwise NOT and indexing with the result would
# be integer indexing, not masking. It is a no-op for a boolean mask.
mask = model_output['attention_mask'].detach().cpu().bool().unsqueeze(-1).expand_as(logits)
# masked_fill returns a new tensor, so the model output is never modified in
# place (on CPU, `.detach().cpu()` shares memory with the original tensor) and
# the cell can be re-run safely.
logits = logits.masked_fill(~mask, -torch.inf)


In [38]:
# Display the mask as booleans; in the recorded output the trailing positions
# of each sequence are False.
mask.bool()

tensor([[[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ..., False, False, False],
         [

In [16]:
# Same in-place masking statement as in the cell that creates `logits` and `mask`.
# NOTE(review): this relies on `mask` being a boolean tensor — confirm.
logits[~mask] = -torch.inf


In [18]:
# Display the logits as a NumPy array (the whole array is dumped to the output).
logits.numpy()

array([[[-1.136662  ,  1.0592679 ,  0.3977642 , ...,  3.5987418 ,
          0.5068621 ,  2.3434474 ],
        [-1.3086665 ,  0.4904939 ,  0.73951775, ..., -0.09715027,
          0.10324742,  0.677319  ],
        [-1.0299996 , -0.40638733,  0.49864063, ...,  0.8964067 ,
         -0.8524088 ,  1.2367059 ],
        ...,
        [-0.17828667, -0.6989492 , -1.1391791 , ..., -1.3364989 ,
          1.4154167 , -1.4605953 ],
        [-0.17870045, -0.6396375 , -1.1487352 , ..., -1.2919847 ,
          1.4239221 , -1.404258  ],
        [-0.18010461, -0.567212  , -1.155138  , ..., -1.2191597 ,
          1.4638362 , -1.3483474 ]],

       [[-1.136662  ,  1.0592679 ,  0.3977642 , ...,  3.5987418 ,
          0.5068621 ,  2.3434474 ],
        [-1.3086665 ,  0.4904939 ,  0.73951775, ..., -0.09715027,
          0.10324742,  0.677319  ],
        [-1.961201  ,  0.630399  ,  0.77966464, ...,  0.20175445,
          0.9396163 , -0.0604695 ],
        ...,
        [-0.48047963, -0.53841233, -0.92524254, ...,  

In [19]:
from scipy.special import log_softmax
import warnings  # kept from the original cell; no longer needed by the code below

# Positions that were masked out are -inf along the whole last axis, so scipy's
# log_softmax computes (-inf) - (-inf) for them, which yields NaN and a NumPy
# "invalid value" RuntimeWarning. Silence that condition for this call only,
# rather than disabling every warning for the rest of the session.
with np.errstate(invalid='ignore'):
    log_probs = log_softmax(logits.numpy(), axis=-1)



  out = tmp - out


In [None]:
# Shape of the maximum log-probability taken over the last axis.
# NOTE(review): the output recorded below is an array of values rather than a
# shape, so it appears to come from an earlier version of this cell.
log_probs.max(axis=-1).shape

array([[-3.1324854, -3.318623 , -1.8815334, -1.8131673, -3.3876677,
        -3.1246696, -3.191782 , -3.0979633, -3.0450742, -2.2308354,
        -2.330618 , -2.472324 , -3.6515033, -3.50948  , -3.3139763,
        -2.6987703, -3.2547076, -3.4764345, -3.1284685, -3.9798079,
        -3.127894 , -3.4546978, -1.6441965, -3.5386233, -2.89884  ,
        -2.824465 , -2.4335928, -2.2494895, -2.1354132, -3.7622044,
        -2.5043337, -3.1157641, -4.0375876, -2.037845 , -3.5541823,
        -3.8744311, -3.2608833, -3.1368597, -3.5858393, -3.9428024,
        -3.9238508, -3.8969798, -3.7878106, -3.6404438, -3.5714214,
        -3.6126592, -3.6882749, -3.63384  , -3.5659509, -3.4978185,
        -3.4211936, -3.3906994, -3.404259 , -3.4438875, -3.451629 ,
        -3.4272156, -3.4222374, -3.4365067, -3.4236848, -3.3792446,
        -3.3100119, -3.207237 , -3.1016018, -3.0563757, -3.0817776,
        -3.1521232, -3.2167647, -3.2216198, -3.1543972, -3.0408382,
        -2.9479885, -2.9263487],
       [-3.1324

: 

: 

In [None]:
def perplexity(logits, base='exp'):
    """Perplexity-style score computed from the most likely token at each position.

    For every position the largest log-probability over the last axis is taken
    (i.e. the log-probability of the arg-max token, not of a target token).
    These values are averaged over the sequence axis, ignoring NaNs, and the
    negated mean is exponentiated. Positions whose logits are all ``-inf``
    produce NaN log-probabilities and are therefore skipped by the mean.

    Args:
        logits: Tensor whose last axis holds the per-token logits and whose
            second-to-last axis is the sequence axis.
        base: ``'exp'`` to return ``exp(-mean)``, or a number ``b`` to return
            ``b ** (-mean)``.

    Returns:
        Tensor with the last two axes of ``logits`` reduced away.

    Raises:
        ValueError: If ``base`` is neither ``'exp'`` nor an int/float.
    """
    # torch.log_softmax is used instead of F.log_softmax so that no extra import is required.
    log_probs = torch.log_softmax(logits, dim=-1).max(dim=-1).values
    mean_log_prob = log_probs.nanmean(dim=-1)

    if base == 'exp':
        return torch.exp(-mean_log_prob)
    if isinstance(base, (int, float)):
        return base ** (-mean_log_prob)
    raise ValueError("Invalid base type. Choose from 'exp' or float or int.")


In [12]:
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'Trainer' object has no attribute 'training'".
trainer.evaluate()

AttributeError: 'Trainer' object has no attribute 'training'

In [29]:
# Inspect the tokenizer's special tokens.
tokenizer.all_special_tokens

['<s>', '</s>', '<pad>', '<unk>', '<mask>', '<lm>', '<cls>', '<mlm>']

In [30]:
# Inspect the ids of the special tokens (503-510 in the recorded output,
# i.e. the last eight ids of the 511-token vocabulary).
tokenizer.all_special_ids

[503, 504, 505, 506, 507, 508, 509, 510]