<a href="https://colab.research.google.com/github/ved1beta/Triton/blob/main/falseattention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install torch triton numpy matplotlib datasets transformers seaborn


Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48

In [8]:
import torch
import triton
import triton.language as tl
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import BertTokenizer
import time
from torch.utils.data import DataLoader
import seaborn as sns
from tqdm import tqdm

# Import the Triton attention implementation from the original code
@triton.jit
def attention_qk_kernel(
    q_ptr, k_ptr, output_ptr,
    seq_length, head_size,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    row_idx = pid // (seq_length // BLOCK_SIZE)
    col_idx = pid % (seq_length // BLOCK_SIZE)

    q_start = q_ptr + row_idx * BLOCK_SIZE * head_size
    k_start = k_ptr + col_idx * BLOCK_SIZE * head_size

    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    for i in range(0, head_size, BLOCK_SIZE):
        q_block = tl.load(q_start + i)
        k_block = tl.load(k_start + i)
        acc += tl.dot(q_block, tl.trans(k_block))

    acc = acc / tl.sqrt(float(head_size))

    output_start = output_ptr + row_idx * BLOCK_SIZE * seq_length + col_idx * BLOCK_SIZE
    tl.store(output_start, acc)

@triton.jit
def softmax_kernel(
    ptr_in, ptr_out,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    row_start = ptr_in + row_idx * n_cols

    row = tl.load(row_start + tl.arange(0, n_cols))
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    tl.store(ptr_out + row_idx * n_cols + tl.arange(0, n_cols), softmax_output)

class TritonAttention(torch.nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

    def forward(self, q, k, v):
        batch_size = q.shape[0]
        seq_len = q.shape[1]

        grid = (batch_size * seq_len * seq_len) // 256

        attention_scores = torch.empty(
            (batch_size, seq_len, seq_len),
            device='cuda',
            dtype=torch.float32
        )

        attention_qk_kernel[grid](
            q.data_ptr(),
            k.data_ptr(),
            attention_scores.data_ptr(),
            seq_len,
            self.head_size,
            BLOCK_SIZE=16
        )

        attention_probs = torch.empty_like(attention_scores)
        softmax_kernel[grid](
            attention_scores.data_ptr(),
            attention_probs.data_ptr(),
            seq_len,
            BLOCK_SIZE=256
        )

        output = torch.empty_like(q)
        attention_qk_kernel[grid](
            attention_probs.data_ptr(),
            v.data_ptr(),
            output.data_ptr(),
            seq_len,
            self.head_size,
            BLOCK_SIZE=16
        )

        return output

def load_and_prepare_data(batch_size=32, max_length=256):
    """Load and prepare the Wikipedia dataset"""
    print("Loading dataset...")
    dataset = load_dataset("wikipedia", "20220301.en", split="train[:1000]")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=max_length
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    return dataloader

def run_attention_comparison(dataloader, head_size=64, num_batches=10):
    """Run comparison between PyTorch and Triton attention"""
    triton_attention = TritonAttention(head_size).cuda()

    triton_times = []
    pytorch_times = []
    accuracy_diffs = []

    print("Running attention comparison...")
    for batch in tqdm(list(dataloader)[:num_batches]):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # Create random Q, K, V tensors based on the input shape
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]

        q = torch.randn(batch_size, seq_len, head_size, device='cuda')
        k = torch.randn(batch_size, seq_len, head_size, device='cuda')
        v = torch.randn(batch_size, seq_len, head_size, device='cuda')

        # Warm-up run
        _ = triton_attention(q, k, v)
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        # Measure Triton performance
        torch.cuda.synchronize()
        start_time = time.time()
        triton_output = triton_attention(q, k, v)
        torch.cuda.synchronize()
        triton_time = time.time() - start_time
        triton_times.append(triton_time)

        # Measure PyTorch performance
        torch.cuda.synchronize()
        start_time = time.time()
        pytorch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        pytorch_time = time.time() - start_time
        pytorch_times.append(pytorch_time)

        # Calculate accuracy difference
        max_diff = torch.max(torch.abs(triton_output - pytorch_output))
        accuracy_diffs.append(max_diff.item())

    return triton_times, pytorch_times, accuracy_diffs

def plot_results(triton_times, pytorch_times, accuracy_diffs):
    """Plot performance comparison results"""
    plt.figure(figsize=(15, 5))

    # Performance comparison
    plt.subplot(1, 2, 1)
    data = {
        'Triton': triton_times,
        'PyTorch': pytorch_times
    }
    sns.boxplot(data=data)
    plt.title('Attention Implementation Performance Comparison')
    plt.ylabel('Time (seconds)')

    # Accuracy differences
    plt.subplot(1, 2, 2)
    plt.plot(accuracy_diffs, marker='o')
    plt.title('Maximum Difference between Implementations')
    plt.xlabel('Batch')
    plt.ylabel('Maximum Absolute Difference')

    plt.tight_layout()
    plt.savefig('attention_comparison_results.png')
    print("Results saved to attention_comparison_results.png")

def main():
    # Configuration
    batch_size = 32
    max_length = 256
    head_size = 64
    num_batches = 10

    # Load data
    dataloader = load_and_prepare_data(batch_size, max_length)

    # Run comparison
    triton_times, pytorch_times, accuracy_diffs = run_attention_comparison(
        dataloader,
        head_size,
        num_batches
    )

    # Plot and save results
    plot_results(triton_times, pytorch_times, accuracy_diffs)

    # Print summary statistics
    print("\nPerformance Summary:")
    print(f"Average Triton Time: {np.mean(triton_times):.4f} seconds")
    print(f"Average PyTorch Time: {np.mean(pytorch_times):.4f} seconds")
    print(f"Average Maximum Difference: {np.mean(accuracy_diffs):.6f}")
    print(f"Speedup: {np.mean(pytorch_times)/np.mean(triton_times):.2f}x")

if __name__ == "__main__":
    main()

Loading dataset...
Running attention comparison...


  0%|          | 0/10 [00:00<?, ?it/s]


TypeError: only integer tensors of a single element can be converted to an index

In [9]:
import torch
import time
from tqdm import tqdm
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import triton
import triton.language as tl

@triton.jit
def attention_qk_kernel(
    q_ptr, k_ptr, output_ptr,
    seq_length, head_size,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    row_idx = pid // (seq_length // BLOCK_SIZE)
    col_idx = pid % (seq_length // BLOCK_SIZE)

    q_start = q_ptr + row_idx * BLOCK_SIZE * head_size
    k_start = k_ptr + col_idx * BLOCK_SIZE * head_size

    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    for i in range(0, head_size, BLOCK_SIZE):
        q_block = tl.load(q_start + i)
        k_block = tl.load(k_start + i)
        acc += tl.dot(q_block, tl.trans(k_block))

    acc = acc / tl.sqrt(float(head_size))

    output_start = output_ptr + row_idx * BLOCK_SIZE * seq_length + col_idx * BLOCK_SIZE
    tl.store(output_start, acc)

@triton.jit
def softmax_kernel(
    ptr_in, ptr_out,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    row_start = ptr_in + row_idx * n_cols

    row = tl.load(row_start + tl.arange(0, n_cols))
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    tl.store(ptr_out + row_idx * n_cols + tl.arange(0, n_cols), softmax_output)

class TritonAttention(torch.nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

    def forward(self, q, k, v):
        batch_size = q.shape[0]
        seq_len = q.shape[1]

        grid = (batch_size * seq_len * seq_len) // 256

        attention_scores = torch.empty(
            (batch_size, seq_len, seq_len),
            device='cuda',
            dtype=torch.float32
        )

        attention_qk_kernel[grid](
            q.data_ptr(),
            k.data_ptr(),
            attention_scores.data_ptr(),
            seq_len,
            self.head_size,
            BLOCK_SIZE=16
        )

        attention_probs = torch.empty_like(attention_scores)
        softmax_kernel[grid](
            attention_scores.data_ptr(),
            attention_probs.data_ptr(),
            seq_len,
            BLOCK_SIZE=256
        )

        output = torch.empty_like(q)
        attention_qk_kernel[grid](
            attention_probs.data_ptr(),
            v.data_ptr(),
            output.data_ptr(),
            seq_len,
            self.head_size,
            BLOCK_SIZE=16
        )

        return output

def load_and_prepare_data(batch_size=32, max_length=256):
    """Load and prepare the Wikipedia dataset"""
    print("Loading dataset...")
    dataset = load_dataset("wikipedia", "20220301.en", split="train[:1000]")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=max_length
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    return dataloader

def run_attention_comparison(dataloader, head_size=64, num_batches=10):
    """Run comparison between PyTorch and Triton attention"""
    triton_attention = TritonAttention(head_size).cuda()

    triton_times = []
    pytorch_times = []
    accuracy_diffs = []

    print("Running attention comparison...")
    for batch in tqdm(list(dataloader)[:num_batches]):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # Ensure input_ids and attention_mask are tensors
        if isinstance(input_ids, list):
            input_ids = torch.tensor(input_ids, device='cuda')
        if isinstance(attention_mask, list):
            attention_mask = torch.tensor(attention_mask, device='cuda')

        # Create random Q, K, V tensors based on the input shape
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]

        q = torch.randn(batch_size, seq_len, head_size, device='cuda')
        k = torch.randn(batch_size, seq_len, head_size, device='cuda')
        v = torch.randn(batch_size, seq_len, head_size, device='cuda')

        # Warm-up run
        _ = triton_attention(q, k, v)
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        # Measure Triton performance
        torch.cuda.synchronize()
        start_time = time.time()
        triton_output = triton_attention(q, k, v)
        torch.cuda.synchronize()
        triton_time = time.time() - start_time
        triton_times.append(triton_time)

        # Measure PyTorch performance
        torch.cuda.synchronize()
        start_time = time.time()
        pytorch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        pytorch_time = time.time() - start_time
        pytorch_times.append(pytorch_time)

        # Calculate accuracy difference
        max_diff = torch.max(torch.abs(triton_output - pytorch_output))
        accuracy_diffs.append(max_diff.item())

    return triton_times, pytorch_times, accuracy_diffs

if __name__ == "__main__":
    dataloader = load_and_prepare_data()
    triton_times, pytorch_times, accuracy_diffs = run_attention_comparison(dataloader)
    print("Triton times:", triton_times)
    print("PyTorch times:", pytorch_times)
    print("Accuracy differences:", accuracy_diffs)


Loading dataset...
Running attention comparison...


  0%|          | 0/10 [00:00<?, ?it/s]


TypeError: only integer tensors of a single element can be converted to an index

In [21]:
import torch
import time
from tqdm import tqdm
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import triton
import triton.language as tl

@triton.jit
def attention_qk_kernel(
    q_ptr, k_ptr, output_ptr,
    seq_length, head_size,
    BLOCK_SIZE: tl.constexpr
):
    # Compute program ID
    pid = tl.program_id(0)
    row_idx = pid // (seq_length // BLOCK_SIZE)
    col_idx = pid % (seq_length // BLOCK_SIZE)

    # Compute pointer offsets for Q and K (without casting to integer types)
    q_start = q_ptr + row_idx * BLOCK_SIZE * head_size
    k_start = k_ptr + col_idx * BLOCK_SIZE * head_size

    # Accumulator for the attention scores
    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    # Iterate over head_size in BLOCK_SIZE chunks
    for i in range(0, head_size, BLOCK_SIZE):
        # Compute chunk offsets for Q and K
        q_chunk = q_start + i
        k_chunk = k_start + i

        # Load Q and K blocks
        q_block = tl.load(q_chunk, mask=None, other=0.0)
        k_block = tl.load(k_chunk, mask=None, other=0.0)

        # Compute dot product and accumulate the result
        acc += tl.dot(q_block, tl.trans(k_block))

    # Normalize the attention scores
    scale = tl.sqrt(float(head_size))
    acc = acc / scale

    # Compute the output pointer offset
    output_start = output_ptr + row_idx * BLOCK_SIZE * seq_length + col_idx * BLOCK_SIZE

    # Store the result
    tl.store(output_start, acc)



@triton.jit
def softmax_kernel(
    ptr_in, ptr_out,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    row_start = ptr_in + row_idx * n_cols

    row = tl.load(row_start + tl.arange(0, n_cols))
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    tl.store(ptr_out + row_idx * n_cols + tl.arange(0, n_cols), softmax_output)

class TritonAttention(torch.nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

    def forward(self, q, k, v):
        batch_size = q.shape[0]
        seq_len = q.shape[1]

        grid = (batch_size * seq_len * seq_len) // 256

        attention_scores = torch.empty(
            (batch_size, seq_len, seq_len),
            device='cuda',
            dtype=torch.float32
        )

        attention_qk_kernel[grid](
            q.data_ptr(),
            k.data_ptr(),
            attention_scores.data_ptr(),
            seq_len,
            self.head_size,
            BLOCK_SIZE=16
        )

        attention_probs = torch.empty_like(attention_scores)
        softmax_kernel[grid](
            attention_scores.data_ptr(),
            attention_probs.data_ptr(),
            seq_len,
            BLOCK_SIZE=256
        )

        output = torch.empty_like(q)
        attention_qk_kernel[grid](
            attention_probs.data_ptr(),
            v.data_ptr(),
            output.data_ptr(),
            seq_len,
            self.head_size,
            BLOCK_SIZE=16
        )

        return output

def load_and_prepare_data(batch_size=32, max_length=256):
    """Load and prepare the Wikipedia dataset"""
    print("Loading dataset...")
    dataset = load_dataset("wikipedia", "20220301.en", split="train[:1000]")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=max_length
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    return dataloader

def run_attention_comparison(dataloader, head_size=64, num_batches=10):
    """Run comparison between PyTorch and Triton attention"""
    triton_attention = TritonAttention(head_size).cuda()

    triton_times = []
    pytorch_times = []
    accuracy_diffs = []

    print("Running attention comparison...")
    for batch in tqdm(list(dataloader)[:num_batches]):
        print(batch)  # Debugging: Check the batch structure
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # Convert to PyTorch tensors if necessary
        if isinstance(input_ids, list):
            input_ids = torch.stack([torch.tensor(x, dtype=torch.long) for x in input_ids], dim=0).to('cuda')
        elif isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.to('cuda')
        else:
            raise ValueError("Unexpected format for input_ids")

        if isinstance(attention_mask, list):
            attention_mask = torch.stack([torch.tensor(x, dtype=torch.long) for x in attention_mask], dim=0).to('cuda')
        elif isinstance(attention_mask, torch.Tensor):
            attention_mask = attention_mask.to('cuda')
        else:
            raise ValueError("Unexpected format for attention_mask")

        # Create random Q, K, V tensors based on the input shape
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]

        q = torch.randn(batch_size, seq_len, head_size, device='cuda')
        k = torch.randn(batch_size, seq_len, head_size, device='cuda')
        v = torch.randn(batch_size, seq_len, head_size, device='cuda')

        # Warm-up run
        _ = triton_attention(q, k, v)
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        # Measure Triton performance
        torch.cuda.synchronize()
        start_time = time.time()
        triton_output = triton_attention(q, k, v)
        torch.cuda.synchronize()
        triton_time = time.time() - start_time
        triton_times.append(triton_time)

        # Measure PyTorch performance
        torch.cuda.synchronize()
        start_time = time.time()
        pytorch_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        pytorch_time = time.time() - start_time
        pytorch_times.append(pytorch_time)

        # Calculate accuracy difference
        max_diff = torch.max(torch.abs(triton_output - pytorch_output))
        accuracy_diffs.append(max_diff.item())

    return triton_times, pytorch_times, accuracy_diffs


if __name__ == "__main__":
    dataloader = load_and_prepare_data()
    triton_times, pytorch_times, accuracy_diffs = run_attention_comparison(dataloader)
    print("Triton times:", triton_times)
    print("PyTorch times:", pytorch_times)
    print("Accuracy differences:", accuracy_diffs)


Loading dataset...
Running attention comparison...


  input_ids = torch.stack([torch.tensor(x, dtype=torch.long) for x in input_ids], dim=0).to('cuda')
  attention_mask = torch.stack([torch.tensor(x, dtype=torch.long) for x in attention_mask], dim=0).to('cuda')
  0%|          | 0/10 [00:00<?, ?it/s]

        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101]), tensor([17712,  9348,  3656, 14813,  2632,  1996,  2824,  2004, 22480,  2572,
        25542,  1996,  6647, 11815,  3935,  1999, 20749, 21358,  2004,  1996,
         1996,  1996, 20312,  5557,  2909,  4300,  3656, 24766,  2004,  8285,
        17076,  3656]), tensor([ 3695,  2260,  1997,  2015, 24141,  5334,  3653,  3334, 18505, 12618,
        12456,  1045,  8740,  8338,  6370,  3418,  1010,  5644, 10085,  2773,
         6041,  9779,  1006,  2162,  3656, 17243,  3523,  1055, 29336,  3868,
        19231,  3523]), tensor([22993,  1006,  5483, 10812,  1010,  9310,  1011,  2964,  3215, 24721,
        11592, 17603, 16570,  2003,  2003, 28625,  2030,  2080, 20469, 29347,
         2314,  4103,  1010, 14854, 13779,  4679,  1006,  1012,  3170,  1006,
         2721,  1006]), tensor([ 2696,  2281,  1006,  1006,  4351,  2024, 14883,  2089,  2024,  8740,
         2121,  3900,  4173,  1996,  103




CompilationError: at 25:18:

    # Accumulator for the attention scores
    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    # Iterate over head_size in BLOCK_SIZE chunks
    for i in range(0, head_size, BLOCK_SIZE):
        # Compute chunk offsets for Q and K
        q_chunk = q_start + i
        k_chunk = k_start + i

        # Load Q and K blocks
        q_block = tl.load(q_chunk, mask=None, other=0.0)
                  ^