In [1]:
import math
import random
import time

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Benchmark grid: batch size plus the sequence lengths and model widths to sweep.
batch_size = 8
seq_lengths = [2 ** exponent for exponent in range(5, 9)]  # 32, 64, 128, 256
dims = [2 ** exponent for exponent in range(6, 10)]  # 64, 128, 256, 512

In [3]:
@torch.compile
def rademacher_sketch(A, B, sketch_size):
    """Approximate ``A @ B`` with a random Rademacher (+1/-1) sketch.

    A sign matrix ``S`` of shape ``(*A.shape[:-2], A.shape[-1], sketch_size)``
    is drawn and the product is estimated as
    ``(A @ S / sketch_size) @ (S^T @ B)``.

    Args:
        A: Left operand; its last dimension is the one that gets sketched.
            Expected to be a floating-point tensor.
        B: Right operand, with the same dtype as ``A`` and a second-to-last
            dimension equal to ``A.shape[-1]``.
        sketch_size: Number of random sign columns in ``S``.

    Returns:
        The sketched estimate, with the same shape as ``torch.matmul(A, B)``.
    """
    S_shape = (*A.shape[:-2], A.shape[-1], sketch_size)
    # Entries are drawn uniformly from {-1, +1}. Cast to A's dtype (rather than
    # always float32) so float64/half inputs do not hit a matmul dtype mismatch.
    # A Gaussian sketch (torch.randn) would be a drop-in alternative here.
    S = (torch.randint(0, 2, S_shape, device=A.device) * 2 - 1).to(A.dtype)
    # E[S S^T] = sketch_size * I for independent signs, so dividing by
    # sketch_size makes the estimate unbiased.
    AS = torch.matmul(A, S) / sketch_size
    SB = torch.matmul(S.transpose(-1, -2), B)
    return torch.matmul(AS, SB)

In [None]:
import torch
import time
import numpy as np

# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Reproducibility: the inputs and the sketch matrices are drawn at random, so
# fix the generator seed to make a Restart & Run All produce the same tensors.
seed = 0
torch.manual_seed(seed)

# Parameters
batch_size = 8
seq_lengths = [32, 64, 128, 256]
dims = [512, 1024, 2048, 4096]
sketch_sizes = [16, 32, 64]  # Number of sketch columns for the RLA product
num_trials = 10  # Timed repetitions averaged per configuration

def rademacher_sketch(A, B, sketch_size):
    """Approximate ``A @ B`` with a random Rademacher (+1/-1) sketch.

    A sign matrix ``S`` of shape ``(*A.shape[:-2], A.shape[-1], sketch_size)``
    is drawn and the product is estimated as
    ``(A @ S / sketch_size) @ (S^T @ B)``.

    Args:
        A: Left operand; its last dimension is the one that gets sketched.
            Expected to be a floating-point tensor.
        B: Right operand, with the same dtype as ``A`` and a second-to-last
            dimension equal to ``A.shape[-1]``.
        sketch_size: Number of random sign columns in ``S``; must be >= 1.

    Returns:
        The sketched estimate, with the same shape as ``torch.matmul(A, B)``.

    Raises:
        ValueError: If ``sketch_size`` is smaller than 1.
    """
    if sketch_size < 1:
        # An empty sketch would silently produce an all-zero "approximation".
        raise ValueError(f"sketch_size must be a positive integer, got {sketch_size}")

    S_shape = (*A.shape[:-2], A.shape[-1], sketch_size)
    # Entries are drawn uniformly from {-1, +1}. Cast to A's dtype (rather than
    # always float32) so float64/half inputs do not hit a matmul dtype mismatch.
    S = (torch.randint(0, 2, S_shape, device=A.device) * 2 - 1).to(A.dtype)
    # E[S S^T] = sketch_size * I for independent signs, so dividing by
    # sketch_size makes the estimate unbiased.
    AS = torch.matmul(A, S) / sketch_size
    SB = torch.matmul(S.transpose(-1, -2), B)
    return torch.matmul(AS, SB)

def time_mm(A, B, method='standard', sketch_size=None):
    """Time one matrix product of ``A`` and ``B``.

    Args:
        A: Left operand tensor.
        B: Right operand tensor.
        method: ``'standard'`` for ``torch.matmul(A, B)`` or ``'rademacher'``
            for ``rademacher_sketch(A, B, sketch_size)``.
        sketch_size: Sketch width forwarded to ``rademacher_sketch``; only
            used when ``method == 'rademacher'``.

    Returns:
        A ``(elapsed_seconds, result)`` tuple.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method not in ('standard', 'rademacher'):
        raise ValueError(
            f"Unknown method {method!r}; expected 'standard' or 'rademacher'"
        )

    # CUDA kernels are launched asynchronously, so without a synchronize on
    # both sides of the timed region the clock would only capture launch
    # overhead. On CPU tensors no synchronization is needed.
    on_gpu = A.is_cuda
    if on_gpu:
        torch.cuda.synchronize(A.device)
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start = time.perf_counter()

    if method == 'standard':
        result = torch.matmul(A, B)
    else:
        result = rademacher_sketch(A, B, sketch_size)

    if on_gpu:
        torch.cuda.synchronize(A.device)
    elapsed = time.perf_counter() - start
    return elapsed, result

# Main experiment: for every (sequence length, model dimension) pair, time the
# exact product Q K^T against the Rademacher-sketched product at each sketch
# size, and record the mean timings and the resulting speedups.
results = []
for seq_len in seq_lengths:
    for dim in dims:
        # Input tensor (batch_size, seq_len, dim)
        token = torch.randn(batch_size, seq_len, dim, device=device) + 2

        # Linear transformation to get Q and K (simulating Transformer projection).
        # NOTE: Q and K are built from the same projection weights, so they hold
        # identical values; only the shapes matter for the timing comparison.
        linear_params = torch.randn(dim, dim, device=device) + 2
        Q = torch.matmul(token, linear_params)  # [batch_size, seq_len, dim]
        K = torch.matmul(token, linear_params)  # [batch_size, seq_len, dim]
        K_T = K.transpose(-1, -2)  # [batch_size, dim, seq_len] for QK^T

        # Standard MM: Q K^T. One untimed warm-up call first, so one-time
        # start-up costs for these shapes do not inflate the first trial.
        time_mm(Q, K_T, method='standard')
        standard_times = [
            time_mm(Q, K_T, method='standard')[0] for _ in range(num_trials)
        ]
        standard_time = np.mean(standard_times)

        # RLA MM for different sketch sizes (same warm-up, then timed trials)
        rla_times = {}
        for sketch_size in sketch_sizes:
            time_mm(Q, K_T, method='rademacher', sketch_size=sketch_size)
            rla_times_trial = [
                time_mm(Q, K_T, method='rademacher', sketch_size=sketch_size)[0]
                for _ in range(num_trials)
            ]
            rla_times[sketch_size] = np.mean(rla_times_trial)

        # Store results; speedup > 1 means the sketched product was faster.
        result = {
            'seq_len': seq_len,
            'dim': dim,
            'standard_time': standard_time,
            'rla_times': rla_times,
            'speedups': {k: standard_time / rla_times[k] for k in rla_times}
        }
        results.append(result)

# Print results as a fixed-width table. The RLA columns are generated from
# sketch_sizes, so the table stays in sync when the list of sizes is changed.
print("\nTiming Results (seconds) and Speedup (Standard / RLA):")
header = f"{'Seq Len':>8} {'Dim':>6} {'Standard':>10}"
for size in sketch_sizes:
    column_label = f"RLA k={size}"
    header += f" {column_label:>10} {'Speedup':>8}"
print(header)
for res in results:
    row = f"{res['seq_len']:>8} {res['dim']:>6} {res['standard_time']:>10.6f}"
    for size in sketch_sizes:
        row += f" {res['rla_times'][size]:>10.6f} {res['speedups'][size]:>8.3f}"
    print(row)


Using device: cpu
