In [1]:
# Imports

import os

from hyformer.configs.dataset import DatasetConfig
from hyformer.configs.tokenizer import TokenizerConfig
from hyformer.configs.model import ModelConfig
from hyformer.configs.trainer import TrainerConfig

from hyformer.utils.datasets.auto import AutoDataset
from hyformer.utils.tokenizers.auto import AutoTokenizer
from hyformer.models.auto import AutoModel
# from hyformer.trainers.trainer import Trainer

from hyformer.utils.runtime import set_seed

from tqdm.auto import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
# Set working directory of the project so the repo-relative config paths below resolve.
# The default is a user-specific absolute path; set the HYFORMER_REPO_DIR environment
# variable to point at your own checkout instead of editing this cell.
REPOSITORY_DIR = os.environ.get('HYFORMER_REPO_DIR', '/home/aih/adam.izdebski/projects/hyformer')
os.chdir(REPOSITORY_DIR)

In [3]:
# Fix the random seed via hyformer's set_seed helper so runs are reproducible.
SEED = 1337
set_seed(SEED)


In [4]:
# Configs
# Shared-filesystem locations for the ICML'25 data and results.
DATA_DIR   = '/lustre/groups/aih/jointformer/icml25/data'
OUTPUT_DIR = '/lustre/groups/aih/jointformer/icml25/results'

# JSON config files, relative to REPOSITORY_DIR (the working directory set above).
PATH_TO_DATASET_CONFIG   = 'configs/datasets/guacamol/config.json'
PATH_TO_TOKENIZER_CONFIG = 'configs/tokenizers/smiles/config.json'
PATH_TO_MODEL_CONFIG     = 'configs/models/hyformer_tiny/config.json'
PATH_TO_TRAINER_CONFIG   = 'configs/trainers/pretrain/config.json'

In [5]:
# Load configs: instantiate each config class from its JSON config file path.
dataset_config, tokenizer_config, model_config, trainer_config = (
    config_cls.from_config_path(config_path)
    for config_cls, config_path in (
        (DatasetConfig, PATH_TO_DATASET_CONFIG),
        (TokenizerConfig, PATH_TO_TOKENIZER_CONFIG),
        (ModelConfig, PATH_TO_MODEL_CONFIG),
        (TrainerConfig, PATH_TO_TRAINER_CONFIG),
    )
)


In [6]:
# Load datasets: the train / val / test splits share one dataset config and root dir.
train_dataset, val_dataset, test_dataset = (
    AutoDataset.from_config(dataset_config, root=DATA_DIR, split=split_name)
    for split_name in ('train', 'val', 'test')
)


In [7]:
# Load tokenizer
# Build the tokenizer described by tokenizer_config (loaded from the SMILES
# tokenizer config path above).
tokenizer = AutoTokenizer.from_config(tokenizer_config)

In [8]:
# Round-trip check: every training SMILES string should survive
# tokenize -> decode unchanged. Mismatches are printed with their index and
# whether the UNK token appeared in the encoding.
# The full train split is large (1,273,104 samples in the recorded run), so set
# MAX_ROUNDTRIP_SAMPLES to an int for a quick partial check; None checks everything.
MAX_ROUNDTRIP_SAMPLES = None

num_checked = (
    len(train_dataset)
    if MAX_ROUNDTRIP_SAMPLES is None
    else min(MAX_ROUNDTRIP_SAMPLES, len(train_dataset))
)
not_found = 0  # number of SMILES whose decoded form differs from the original

for idx in tqdm(range(num_checked)):
    smiles = train_dataset[idx]['data']
    input_ids = tokenizer(smiles, task="lm")['input_ids']
    # Decode once and reuse it in the report instead of re-tokenizing/decoding.
    decoded = tokenizer.decode(input_ids)
    if smiles != decoded:
        print("-" * 100)
        print(f"IDX: {idx}")
        print(f"UNK token found: {tokenizer.unk_token_id in input_ids}")
        print(smiles)
        print(decoded)
        not_found += 1
        print("-" * 100)
print(f"Round-trip mismatches: {not_found}/{num_checked}")



  0%|          | 0/1273104 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [23]:
# Load model
# Build the model from model_config (the hyformer_tiny config path above). No
# pretrained checkpoint path is passed here, so weights are presumably freshly
# initialised — TODO confirm against AutoModel.from_config.
# TODO: assert that the attention mask is boolean (author's reminder; nothing in
# this cell checks it).
model = AutoModel.from_config(model_config)


In [11]:
# Attach a prediction head for one classification task.
model.init_prediction_head(prediction_task_type='classification', num_prediction_tasks=1)

In [12]:
import torch 

In [26]:
# Move the model to the GPU. The device is hard-coded to 'cuda' (no CPU fallback);
# as the cell's last expression, the returned module's repr is displayed.
model.to('cuda')

Hyformer(
  (token_embedding): Embedding(596, 512)
  (layers): ModuleList(
    (0): TransformerLayer(
      (attention_layer): Attention(
        (q_proj): Linear(in_features=512, out_features=512, bias=False)
        (k_proj): Linear(in_features=512, out_features=512, bias=False)
        (v_proj): Linear(in_features=512, out_features=512, bias=False)
        (out): Linear(in_features=512, out_features=512, bias=False)
        (relative_embedding): RotaryEmbedding()
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=512, out_features=2048, bias=False)
        (w3): Linear(in_features=512, out_features=2048, bias=False)
        (w2): Linear(in_features=2048, out_features=512, bias=False)
      )
      (attention_layer_normalization): RMSNorm()
      (feed_forward_normalization): RMSNorm()
    )
  )
  (layer_norm): RMSNorm()
  (lm_head): Linear(in_features=512, out_features=596, bias=False)
  (mlm_head): Linear(in_features=512, out_features=596, bias=False)
)

In [27]:
# Sample up to 10 new tokens from a prompt consisting of the LM task token + BOS.
# NOTE(review): this call raised IndexError in the recorded run. test_generation()
# further down calls generate() with prefix_input_ids=/num_tokens_to_generate=
# rather than idx=/max_new_tokens=, so these keywords may belong to an older
# generate() signature — confirm against the model's current API.
lm_prompt = torch.tensor(
    [[tokenizer.task_token_id('lm'), tokenizer.bos_token_id]],
    dtype=torch.long,
    device='cuda',
)
samples = model.generate(
    idx=lm_prompt,
    max_new_tokens=10,
    temperature=1.0,
    top_k=None,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

IndexError: list index out of range

In [None]:
def test_generation(model, tokenizer):
    """
    Test various components of the generation process.
    
    Args:
        model: The Hyformer model
        tokenizer: The tokenizer used with the model
    """
    # Move model to cuda if available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()
    
    # Create a simple input sequence
    prefix = "CC"  # Simple SMILES string
    input_ids = tokenizer(prefix, task="lm")['input_ids']
    prefix_input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    
    print("1. Testing basic generation:")
    print("-" * 50)
    try:
        outputs = model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=5,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            temperature=1.0,
            top_k=25
        )
        print("✓ Basic generation successful")
        print(f"Input SMILES: {prefix}")
        print(f"Generated SMILES: {tokenizer.decode(outputs[0])}")
    except Exception as e:
        print(f"✗ Basic generation failed: {str(e)}")
    
    print("\n2. Testing EOS token handling:")
    print("-" * 50)
    try:
        # Force early EOS by using low temperature
        outputs = model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=10,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            temperature=0.1,
            top_k=25
        )
        # Check if output contains EOS and PAD tokens
        has_eos = tokenizer.eos_token_id in outputs[0]
        has_pad = tokenizer.pad_token_id in outputs[0]
        print(f"✓ EOS token present: {has_eos}")
        print(f"✓ PAD token present: {has_pad}")
        print(f"Generated sequence: {[int(x) for x in outputs[0]]}")
    except Exception as e:
        print(f"✗ EOS token test failed: {str(e)}")
    
    print("\n3. Testing KV caching:")
    print("-" * 50)
    try:
        # Generate with and without caching and compare time
        import time
        
        start_time = time.time()
        _ = model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=10,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=False
        )
        no_cache_time = time.time() - start_time
        
        start_time = time.time()
        _ = model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=10,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=True
        )
        cache_time = time.time() - start_time
        
        print(f"✓ Generation time without cache: {no_cache_time:.4f}s")
        print(f"✓ Generation time with cache: {cache_time:.4f}s")
        print(f"✓ Speedup from caching: {no_cache_time/cache_time:.2f}x")
    except Exception as e:
        print(f"✗ KV caching test failed: {str(e)}")
    
    print("\n4. Testing batch generation:")
    print("-" * 50)
    try:
        # Create a batch of 3 sequences
        batch_input_ids = torch.cat([prefix_input_ids] * 3, dim=0)
        outputs = model.generate(
            prefix_input_ids=batch_input_ids,
            num_tokens_to_generate=5,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
        print(f"✓ Batch generation successful")
        print(f"Batch size: {outputs.shape[0]}")
        for i, seq in enumerate(outputs):
            print(f"Sequence {i}: {tokenizer.decode(seq)}")
    except Exception as e:
        print(f"✗ Batch generation failed: {str(e)}")
    
    print("\n5. Testing sampling parameters:")
    print("-" * 50)
    try:
        # Test different temperatures
        temps = [0.1, 1.0, 2.0]
        for temp in temps:
            outputs = model.generate(
                prefix_input_ids=prefix_input_ids,
                num_tokens_to_generate=5,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                temperature=temp,
                top_k=25
            )
            print(f"Temperature {temp}: {tokenizer.decode(outputs[0])}")
    except Exception as e:
        print(f"✗ Sampling parameter test failed: {str(e)}")

# Usage example: run the generation smoke tests on the model and tokenizer
# loaded in the cells above.
test_generation(model, tokenizer)