In [None]:
def test_generation(model, tokenizer):
    """
    Smoke-test the model's ``generate`` method and print a report.

    Five independent checks are run in order: basic generation, EOS/PAD
    token presence, generation time with and without ``use_cache``, batch
    generation, and generation at several temperatures. Each check sits in
    its own ``try``/``except`` so a failure is reported (line prefixed with
    "✗") without preventing the remaining checks from running.

    Args:
        model: The Hyformer model. This function calls ``model.to(device)``,
            ``model.eval()`` and ``model.generate(...)`` on it.
        tokenizer: The tokenizer used with the model. It is called as
            ``tokenizer(text, task="lm")`` and indexed with ``'input_ids'``;
            its ``eos_token_id``, ``pad_token_id`` and ``decode`` are used.

    Returns:
        None. All results are printed to stdout.
    """
    # Kept as a function-scope import (as in the original) so the block does
    # not rely on a module-level import that is not visible here.
    import time

    # Move model to cuda if available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()

    # Create a simple input sequence
    prefix = "CC"  # Simple SMILES string
    input_ids = tokenizer(prefix, task="lm")['input_ids']
    prefix_input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    def _timed_generate(use_cache):
        """Return the wall-clock seconds taken by one 10-token generate call."""
        # CUDA work is launched asynchronously: synchronize before starting
        # and before stopping the timer so the measurement covers the actual
        # GPU execution rather than just the kernel launches.
        if device == 'cuda':
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=10,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=use_cache
        )
        if device == 'cuda':
            torch.cuda.synchronize()
        return time.perf_counter() - start

    print("1. Testing basic generation:")
    print("-" * 50)
    try:
        outputs = model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=5,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            temperature=1.0,
            top_k=25
        )
        print("✓ Basic generation successful")
        print(f"Input SMILES: {prefix}")
        print(f"Generated SMILES: {tokenizer.decode(outputs[0])}")
    except Exception as e:
        print(f"✗ Basic generation failed: {e}")

    print("\n2. Testing EOS token handling:")
    print("-" * 50)
    try:
        # A low temperature is used in the hope of getting an early EOS.
        # Nothing here guarantees one, so either flag below may be False.
        outputs = model.generate(
            prefix_input_ids=prefix_input_ids,
            num_tokens_to_generate=10,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            temperature=0.1,
            top_k=25
        )
        # Check if output contains EOS and PAD tokens
        has_eos = tokenizer.eos_token_id in outputs[0]
        has_pad = tokenizer.pad_token_id in outputs[0]
        print(f"✓ EOS token present: {has_eos}")
        print(f"✓ PAD token present: {has_pad}")
        print(f"Generated sequence: {[int(x) for x in outputs[0]]}")
    except Exception as e:
        print(f"✗ EOS token test failed: {e}")

    print("\n3. Testing KV caching:")
    print("-" * 50)
    try:
        # Generate with and without caching and compare time
        no_cache_time = _timed_generate(use_cache=False)
        cache_time = _timed_generate(use_cache=True)

        print(f"✓ Generation time without cache: {no_cache_time:.4f}s")
        print(f"✓ Generation time with cache: {cache_time:.4f}s")
        if cache_time > 0:
            print(f"✓ Speedup from caching: {no_cache_time/cache_time:.2f}x")
        else:
            # A zero duration would otherwise raise ZeroDivisionError and be
            # misreported as a caching failure.
            print("✓ Speedup from caching: n/a (cached run too fast to measure)")
    except Exception as e:
        print(f"✗ KV caching test failed: {e}")

    print("\n4. Testing batch generation:")
    print("-" * 50)
    try:
        # Create a batch of 3 sequences
        batch_input_ids = torch.cat([prefix_input_ids] * 3, dim=0)
        outputs = model.generate(
            prefix_input_ids=batch_input_ids,
            num_tokens_to_generate=5,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
        print("✓ Batch generation successful")
        print(f"Batch size: {outputs.shape[0]}")
        for i, seq in enumerate(outputs):
            print(f"Sequence {i}: {tokenizer.decode(seq)}")
    except Exception as e:
        print(f"✗ Batch generation failed: {e}")

    print("\n5. Testing sampling parameters:")
    print("-" * 50)
    try:
        # Test different temperatures
        temps = [0.1, 1.0, 2.0]
        for temp in temps:
            outputs = model.generate(
                prefix_input_ids=prefix_input_ids,
                num_tokens_to_generate=5,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                temperature=temp,
                top_k=25
            )
            print(f"Temperature {temp}: {tokenizer.decode(outputs[0])}")
    except Exception as e:
        print(f"✗ Sampling parameter test failed: {e}")

# Run the generation checks. `model` and `tokenizer` are expected to be
# defined before this point (they are not created in this cell).
test_generation(model=model, tokenizer=tokenizer)