In [13]:
import torch
import math
from transformers import PretrainedConfig
from iDistilbert import iMultiHeadSelfAttention  # Replace 'your_module' with the actual module name

def test_imultihead_attention():
    # Set up a configuration
    config = PretrainedConfig()
    config.n_heads = 8
    config.dim = 512
    config.attention_dropout = 0.1
    config.distance_metric = 'manhattan_distance'
    config.activation_function = 'relu'
    config.signed_inhibitor = True
    config.alpha = 0
    config.center = False

    # Initialize the attention module
    attention = iMultiHeadSelfAttention(config)

    # Create sample inputs
    batch_size = 2
    seq_length = 10
    dim = config.dim
    query = torch.randn(batch_size, seq_length, dim)
    key = torch.randn(batch_size, seq_length, dim)
    value = torch.randn(batch_size, seq_length, dim)
    mask = torch.ones(batch_size, seq_length)

    def shape(x):
        return x.view(batch_size, -1, config.n_heads, dim // config.n_heads).transpose(1, 2)

    q = shape(attention.q_lin(query))
    k = shape(attention.k_lin(key))
    v = shape(attention.v_lin(value))

    # Test forward pass
    output = attention(query, key, value, mask)

    # Basic checks
    assert isinstance(output, tuple), "Output should be a tuple"
    assert len(output) == 1, "Output tuple should have length 1 when output_attentions is False"
    context = output[0]
    assert context.shape == (batch_size, seq_length, dim), f"Expected shape {(batch_size, seq_length, dim)}, got {context.shape}"

    # Test with output_attentions=True
    output_with_attentions = attention(query, key, value, mask, output_attentions=True)
    assert len(output_with_attentions) == 2, "Output tuple should have length 2 when output_attentions is True"
    context, attentions = output_with_attentions
    assert attentions.shape == (batch_size, config.n_heads, seq_length, seq_length), \
        f"Expected attention shape {(batch_size, config.n_heads, seq_length, seq_length)}, got {attentions.shape}"

    # Test Manhattan distance calculation
    q = attention.q_lin(query).view(batch_size, seq_length, config.n_heads, -1).transpose(1, 2)
    k = attention.k_lin(key).view(batch_size, seq_length, config.n_heads, -1).transpose(1, 2)
    manhattan_dist = torch.cdist(q, k, p=1) / math.sqrt(dim // config.n_heads)
    print("manhattan:", manhattan_dist, "attention:", attentions)
    assert torch.allclose(manhattan_dist, attentions, atol=1e-4), "Manhattan distance calculation is incorrect"

    # Test alpha shift
    if config.alpha > 0:
        assert torch.all(attentions <= manhattan_dist - config.alpha), "Alpha shift is not applied correctly"

    # Test centering
    if config.center: 
        centered_mean = torch.mean(attentions, dim=-1, keepdim=True)
        assert torch.allclose(centered_mean, torch.zeros_like(centered_mean), atol=1e-5), "Centering is not applied correctly"

    # Test signed inhibitor
    v = attention.v_lin(value).view(batch_size, seq_length, config.n_heads, -1).transpose(1, 2)
    v_t = v.transpose(-1, -2)
    pos_v = torch.nn.functional.relu(v_t)
    neg_v = -torch.nn.functional.relu(-v_t)
    v_sum = torch.sum(v, dim=-2, keepdim=True)
    dist1 = torch.cdist(attentions, pos_v, p=1)
    dist2 = torch.cdist(attentions, -neg_v, p=1)
    expected_context = 0.5 * (v_sum + dist1 - dist2)
    expected_context = expected_context.transpose(1, 2).contiguous().view(batch_size, seq_length, dim)
    expected_context = attention.out_lin(expected_context)
    assert torch.allclose(context, expected_context, atol=1e-4), "Signed inhibitor calculation is incorrect"

    print("All tests passed!")


In [14]:
test_imultihead_attention()

manhattan: tensor([[[[6.0772, 5.2858, 6.0801,  ..., 5.6185, 5.7522, 5.2338],
          [5.9760, 5.2336, 5.8107,  ..., 4.9800, 4.5297, 5.7370],
          [6.4357, 4.8345, 5.5585,  ..., 4.7197, 5.5091, 5.6305],
          ...,
          [5.0999, 4.7994, 5.2691,  ..., 5.1425, 4.5558, 5.3494],
          [5.5670, 4.3111, 4.8021,  ..., 4.8845, 5.0862, 5.7810],
          [5.2251, 5.9267, 5.5819,  ..., 6.2157, 5.5792, 6.9216]],

         [[5.0905, 6.3175, 4.7354,  ..., 5.7119, 5.8099, 6.0322],
          [5.1588, 5.2408, 5.3311,  ..., 4.9445, 5.7702, 5.6704],
          [4.6449, 5.7244, 4.7422,  ..., 4.3380, 6.0153, 4.6274],
          ...,
          [5.0343, 4.9869, 4.5839,  ..., 5.5890, 4.8580, 5.2585],
          [4.1412, 5.3241, 4.8187,  ..., 4.7849, 5.0593, 5.2501],
          [4.7986, 5.5238, 5.3539,  ..., 5.0259, 5.8376, 5.3515]],

         [[6.4377, 5.1312, 5.6195,  ..., 5.1679, 5.4747, 5.8638],
          [4.8503, 4.1265, 4.8440,  ..., 4.6355, 4.6717, 4.9571],
          [5.1298, 4.5253, 4.72

In [29]:
context = torch.tensor([1,2,3]) 
context *= 0.5
print(context)

RuntimeError: result type Float can't be cast to the desired output type Long

In [2]:
from datasets import load_dataset
dataset = load_dataset("bookcorpus/bookcorpus")

  from .autonotebook import tqdm as notebook_tqdm
Downloading builder script: 100%|██████████| 3.25k/3.25k [00:00<00:00, 6.86MB/s]
Downloading readme: 100%|██████████| 18.5k/18.5k [00:00<00:00, 14.4MB/s]


The repository for bookcorpus/bookcorpus contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/bookcorpus/bookcorpus.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


Downloading data: 100%|██████████| 1.18G/1.18G [00:35<00:00, 33.0MB/s] 
Generating train split: 100%|██████████| 74004228/74004228 [11:24<00:00, 108149.19 examples/s]


In [4]:
import torch
import torch.nn.functional as F

def test_cdist(dtype, simulated_dtype=None):
    # Create random tensors
    x = torch.randn(3, 5)
    y = torch.randn(2, 5)
    
    if simulated_dtype == 'int4':
        # Simulate int4 by clamping values and rounding
        x = torch.clamp(x * 7, -8, 7).round()
        y = torch.clamp(y * 7, -8, 7).round()
    
    # Convert to specified dtype
    x = x.to(dtype)
    y = y.to(dtype)
    
    # For integer types, we need to convert to float for cdist
    if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        x_float = x.float()
        y_float = y.float()
    else:
        x_float = x
        y_float = y
    
    # Compute pairwise distance
    dist = torch.cdist(x_float, y_float)
    
    print(f"Testing cdist with {simulated_dtype or dtype}:")
    print(f"Input X shape: {x.shape}, dtype: {x.dtype}")
    print(f"Input Y shape: {y.shape}, dtype: {y.dtype}")
    print(f"Output shape: {dist.shape}, dtype: {dist.dtype}")
    print(f"Output:\n{dist}\n")

# Test with different dtypes
test_cdist(torch.float32)  # fp32
#test_cdist(torch.float16)  # fp16
test_cdist(torch.int8)     # int8
test_cdist(torch.float32, 'int4')  # simulated int4

Testing cdist with torch.float32:
Input X shape: torch.Size([3, 5]), dtype: torch.float32
Input Y shape: torch.Size([2, 5]), dtype: torch.float32
Output shape: torch.Size([3, 2]), dtype: torch.float32
Output:
tensor([[2.6425, 2.3108],
        [4.1014, 2.9532],
        [3.4673, 2.9974]])

Testing cdist with torch.int8:
Input X shape: torch.Size([3, 5]), dtype: torch.int8
Input Y shape: torch.Size([2, 5]), dtype: torch.int8
Output shape: torch.Size([3, 2]), dtype: torch.float32
Output:
tensor([[1.4142, 1.4142],
        [1.4142, 1.4142],
        [1.4142, 1.4142]])

Testing cdist with int4:
Input X shape: torch.Size([3, 5]), dtype: torch.float32
Input Y shape: torch.Size([2, 5]), dtype: torch.float32
Output shape: torch.Size([3, 2]), dtype: torch.float32
Output:
tensor([[20.4939, 15.4919],
        [27.9643, 15.2971],
        [17.3494, 13.4536]])

