In [5]:
import time
import torch

# Suppose you want num_streams=2, batch_size=4, so total=8 samples
num_streams = 2
batch_size = 4
M, N = 1000, 1000

# Create a single tensor of shape (total, M, N)
total = num_streams * batch_size
x_total = torch.randn(total, M, N, device='cuda')

# Split x_total into 2 chunks of size 4
chunks = x_total.chunk(num_streams, dim=0)

# Preallocate outputs
outs_stream = [torch.empty_like(chunks[i].transpose(1,2).bmm(chunks[i])) 
               for i in range(num_streams)]
streams = [torch.cuda.Stream() for _ in range(num_streams)]

torch.cuda.synchronize()
t0 = time.time()
for i, (inp, out, stream) in enumerate(zip(chunks, outs_stream, streams)):
    with torch.cuda.stream(stream):
        out.copy_(torch.bmm(inp, inp.transpose(1,2)))
torch.cuda.synchronize()
stream_time = time.time() - t0


from torch.func import vmap

# Define per-sample function
def f_single(x):
    return x @ x.transpose(0, 1)  # (M,N) -> (M,M)

# Batch-lift it
f_vmap = vmap(f_single)

torch.cuda.synchronize()
t0 = time.time()
outs_vmap = f_vmap(x_total)    # processes all 8 samples
torch.cuda.synchronize()
vmap_time = time.time() - t0

outs_seq = torch.empty(total, M, M, device='cuda')

torch.cuda.synchronize()
t0 = time.time()
for i in range(total):
    outs_seq[i].copy_(x_total[i] @ x_total[i].transpose(0,1))
torch.cuda.synchronize()
seq_time = time.time() - t0

print(f"Stream (2×4) : {stream_time:.4f} s")
print(f"Sequential(8): {seq_time:.4f} s")
print(f"vmap(8)      : {vmap_time:.4f} s")


Stream (2×4) : 0.0029 s
Sequential(8): 0.0009 s
vmap(8)      : 0.0009 s
