In [7]:
import os
import torch.multiprocessing as mp
import torch.distributed as dist
import torch

def distributed_init():
    """Join the default NCCL process group and bind this process to a GPU.

    The CUDA device index is taken from the ``LOCAL_RANK`` environment
    variable when it is set (``torchrun`` exports it), because the global
    rank returned by ``dist.get_rank()`` can exceed the number of GPUs on a
    single node in multi-node jobs. When ``LOCAL_RANK`` is absent the global
    rank is used, which is what this function always did before.

    Returns:
        tuple: ``(rank, world_size)`` as reported by ``dist.get_rank()`` and
        ``dist.get_world_size()``.
    """
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Per-node device index; falls back to the global rank for plain launches.
    device_index = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(device_index)

    return rank, world_size

In [8]:
def find_free_port():
    """Ask the OS for a currently unused TCP port and return it as a string.

    Binding to port 0 lets the kernel pick the port. The probing socket is
    closed before returning, so the number only says the port was free at
    the time of the call.

    Recipe: https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
    """
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
        return str(port)
    finally:
        probe.close()

In [12]:
# Rendezvous settings for a two-process group on the local machine.
master_addr = '127.0.0.1'
master_port = find_free_port()  # returned as a string
world_size = 2

In [13]:
master_port

'57727'

In [None]:
from datetime import timedelta

# This cell acts as rank 0 of the group configured above.
# NOTE(review): world_size comes from an earlier cell (2 there) while only
# rank 0 is started here; presumably another process joins as rank 1 within
# the timeout -- confirm.
rank = 0
backend = 'nccl'

print(f'setting up {rank=} {world_size=} {backend=}')

# Publish the master's address and port through the environment so the
# "env://" init method can pick them up.
os.environ.update(MASTER_ADDR=master_addr, MASTER_PORT=master_port)
print(f"{master_addr=} {master_port=}")

# Create the default process group, then tear it down again straight away.
init_timeout = timedelta(seconds=30)
dist.init_process_group(
    backend,
    init_method="env://",
    rank=rank,
    world_size=world_size,
    timeout=init_timeout,
)
print(f"{rank=} init complete")
dist.destroy_process_group()
print(f"{rank=} destroy complete")

In [None]:
# Shell command, not Python: as a plain code cell this line is a SyntaxError.
# Run it from a terminal in the directory that contains test/, or prefix it
# with `!` to launch it from the notebook.
# CUDA_VISIBLE_DEVICES=6,7 OMP_NUM_THREADS=48 torchrun --nproc_per_node=2 test/TP_baseline.py

In [5]:
# One ablation command per budget: 128*3, 128*4, ..., 128*99.
for i in range(3, 100):
    budget = 128 * i
    print(
        "CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py"
        f" --prefill {122880} --budget {budget}"
        " --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6"
    )

CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 384 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 512 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 640 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 768 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 896 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 1024 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 1152 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
CUDA_VISIBLE_DEVICES=0 python test/e2e_ablation.py --prefill 122880 --budget 1280 --chunk_size 

In [6]:
# Load the tokenizer published under the "JackFram/llama-68m" model id; it is
# used below to decode a list of token ids back to text.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-68m")

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [19]:
tokenizer.decode([263, 1568,  263,  590,  777,  263])

'a much a my some a'

In [24]:
def count_visible_devices(cuda):
    """Return how many device ids a CUDA_VISIBLE_DEVICES string lists.

    Args:
        cuda: Comma-separated device ids, e.g. ``'1,2,3,4'``.

    Returns:
        int: Number of non-empty ids. Multi-digit ids such as ``'10'`` count
        once; the previous ``len(cuda)//2+1`` counted characters and was only
        right for single-digit ids.
    """
    return len([device for device in cuda.split(',') if device.strip()])


bsz_list = [1, 2, 4, 8, 16, 32, 64, 128]
cuda_list = ['1', '1,2', '1,2,3,4', '1,2,3,4,5,6,8,9']

# CUDA_VISIBLE_DEVICES=5,6 OMP_NUM_THREADS=48 torchrun --nproc_per_node=2 test/TP_baseline.py --prefill 128 --bsz 1 --gen_len 32

# Largest configurations first: widest GPU set, then largest batch size.
for cuda in reversed(cuda_list):
    nproc = count_visible_devices(cuda)  # one torchrun worker per listed GPU
    for bsz in reversed(bsz_list):
        print(f"CUDA_VISIBLE_DEVICES={cuda} OMP_NUM_THREADS=48 torchrun --nproc_per_node={nproc} test/TP_baseline.py --prefill 128 --bsz {bsz} --gen_len 32")

CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 128 --gen_len 32
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 64 --gen_len 32
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 32 --gen_len 32
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 16 --gen_len 32
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 8 --gen_len 32
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 4 --gen_len 32
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,8,9 OMP_NUM_THREADS=48 torchrun --nproc_per_node=8 test/TP_baseline.py --prefill 128 --bsz 2 --gen_len 32
CUDA_VISIBLE_DEVICES=1

In [1]:
import torch

# Load the two saved dumps that the following cells compare; both files live
# in the parent directory.
# NOTE(review): torch.load unpickles the file, so only open dumps you trust.
# NOTE(review): the name `re` shadows the standard-library module `re` for the
# rest of the session; presumably `gt` is the reference run and `re` the
# re-implementation under test -- confirm.
gt = torch.load('../gt.pt')
re = torch.load('../re.pt')

In [2]:
gt.keys()

dict_keys(['attn_output', 'key_states', 'value_states', 'query_states', 'position_ids'])

In [14]:
torch.allclose(gt['attn_output'], re['attn_output'], rtol=1e-4,atol=1e-4)

True

In [17]:
torch.allclose(gt['query_states'], re['query_states'], rtol=1e-3,atol=1e-3)

True

In [18]:
re['key_states'].shape

torch.Size([1, 129, 32, 128])

In [19]:
gt['key_states'].shape

torch.Size([1, 640, 32, 128])

In [7]:
# Compare the key_states entry at position 128 of batch item 0 in both dumps.
# No rtol/atol is passed here, unlike the attn_output and query_states checks
# above, so the library defaults apply.
torch.allclose(gt['key_states'][0][128], re['key_states'][0][128])

False

In [1]:
not False

True

In [8]:
gt['key_states'][0][0]

tensor([[-0.3528,  1.0146,  0.5854,  ...,  0.1339,  0.1257, -0.2622],
        [-0.2218, -0.5117,  0.0141,  ...,  0.8779,  0.8389,  0.8540],
        [-0.1284,  0.5469,  0.6167,  ...,  1.6211,  0.9756,  0.1910],
        ...,
        [-1.0947,  0.0994,  0.0122,  ...,  1.2227, -0.4485, -0.0783],
        [ 0.1693,  0.2878, -0.2220,  ..., -0.6587,  0.9082, -0.0146],
        [ 0.0875,  0.0104,  0.1571,  ...,  0.3452, -0.3425,  0.2004]],
       device='cuda:0', dtype=torch.float16)

In [9]:
re['key_states'][0][0]

tensor([[-0.3528,  1.0146,  0.5854,  ...,  0.1339,  0.1257, -0.2622],
        [-0.2218, -0.5117,  0.0141,  ...,  0.8779,  0.8389,  0.8540],
        [-0.1284,  0.5469,  0.6167,  ...,  1.6211,  0.9756,  0.1910],
        ...,
        [-1.0947,  0.0994,  0.0122,  ...,  1.2227, -0.4485, -0.0783],
        [ 0.1693,  0.2878, -0.2220,  ..., -0.6587,  0.9082, -0.0146],
        [ 0.0875,  0.0104,  0.1571,  ...,  0.3452, -0.3425,  0.2004]],
       device='cuda:0', dtype=torch.float16)

In [10]:
re['key_states'][0].shape

torch.Size([129, 32, 128])

In [11]:
def print_unmatched_tokens(gt_states, re_states, head):
    """Print every token of ``gt_states`` that has no close row in ``re_states``.

    Both arguments are indexed as ``[token][head]`` and the trailing axis is
    compared elementwise. Only the first ``re_states.shape[0]`` tokens of
    ``gt_states`` are looked up, and each one may match any row of
    ``re_states`` for the same head.

    A single broadcast ``torch.isclose(...).all(-1)`` replaces the per-pair
    ``torch.allclose`` calls; both apply the same default tolerances. The
    intermediate mask has tokens x tokens x dim elements, so this is meant
    for short caches like the one inspected here.

    Args:
        gt_states: Reference tensor, at least as many tokens as ``re_states``.
        re_states: Tensor to search in.
        head: Head index to check.
    """
    num_tokens = re_states.shape[0]
    wanted = gt_states[:num_tokens, head]
    available = re_states[:, head]
    # pairwise[t, i] is True when gt token t is close to re row i.
    pairwise = torch.isclose(wanted.unsqueeze(1), available.unsqueeze(0)).all(dim=-1)
    found = pairwise.any(dim=-1)
    for token in torch.nonzero(~found).flatten().tolist():
        print(f"head {head} token {token} not found")


# Head count is read from the `re` dump instead of being hard-coded.
# For each head the key misses are printed first, then the value misses.
for head in range(re['key_states'].shape[2]):
    print_unmatched_tokens(gt['key_states'][0], re['key_states'][0], head)
    print_unmatched_tokens(gt['value_states'][0], re['value_states'][0], head)
    print(f"head {head} done")

head 0 token 128 not found
head 0 token 128 not found
head 0 done
head 1 token 128 not found
head 1 token 128 not found
head 1 done
head 2 token 128 not found
head 2 token 128 not found
head 2 done
head 3 token 128 not found
head 3 token 128 not found
head 3 done
head 4 token 128 not found
head 4 token 128 not found
head 4 done
head 5 token 128 not found
head 5 token 128 not found
head 5 done
head 6 token 128 not found
head 6 token 128 not found
head 6 done
head 7 token 128 not found
head 7 token 128 not found
head 7 done
head 8 token 128 not found
head 8 token 128 not found
head 8 done
head 9 token 128 not found
head 9 token 128 not found
head 9 done
head 10 token 128 not found
head 10 token 128 not found
head 10 done
head 11 token 128 not found
head 11 token 128 not found
head 11 done
head 12 token 128 not found
head 12 token 128 not found
head 12 done
head 13 token 128 not found
head 13 token 128 not found
head 13 done
head 14 token 128 not found
head 14 token 128 not found
head 14 

In [68]:
from flash_attn import flash_attn_with_kvcache

# Recompute attention from the query/key/value tensors saved in `re`.
# NOTE(review): softmax_scale is passed as a float16 tensor holding
# 1/sqrt(128), where 128 matches the last dimension of key_states shown
# above. A plain Python float (128 ** -0.5) would avoid rounding the scale to
# half precision -- confirm whether that matters for the comparisons below.
flash_attn_with_kvcache(q=re['query_states'], k_cache=re['key_states'], v_cache=re['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True)

tensor([[[[-6.0692e-03, -4.2272e-04,  2.1660e-04,  ..., -6.7062e-03,
            3.6836e-04, -9.9640e-03],
          [-4.0398e-03, -1.3342e-03,  5.1796e-05,  ..., -2.6207e-03,
            9.5177e-04, -7.3195e-04],
          [ 9.0504e-04,  2.3041e-03,  5.5962e-03,  ...,  3.0384e-03,
            3.1281e-03,  3.0594e-03],
          ...,
          [-3.2539e-03,  1.1887e-02, -1.7822e-02,  ...,  6.0310e-03,
            2.9774e-03, -1.3466e-02],
          [-9.0179e-03, -3.5858e-03,  2.0409e-03,  ...,  6.1493e-03,
            9.6054e-03,  1.3199e-02],
          [-6.7291e-03,  2.5463e-03, -4.2114e-03,  ..., -8.3542e-03,
           -9.3651e-04,  4.4365e-03]]]], device='cuda:0', dtype=torch.float16)

In [69]:
gt['attn_output']

tensor([[[-0.0061, -0.0004,  0.0002,  ..., -0.0084, -0.0009,  0.0044]]],
       device='cuda:0', dtype=torch.float16)

In [70]:
# Flatten the kernel output to (1, 1, 4096) so it lines up with
# gt['attn_output'], then compare.
# NOTE(review): no rtol/atol is given here, while the attn_output comparisons
# that return True in this notebook pass rtol=1e-4, atol=1e-4.
torch.allclose(gt['attn_output'],flash_attn_with_kvcache(q=re['query_states'], k_cache=re['key_states'], v_cache=re['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True).reshape(1, 1, 4096))

False

In [71]:
flash_attn_with_kvcache(q=re['query_states'], k_cache=re['key_states'], v_cache=re['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True).shape

torch.Size([1, 1, 32, 128])

In [72]:
flash_attn_with_kvcache(q=re['query_states'], k_cache=re['key_states'], v_cache=re['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True).reshape(1, 1, 4096)

tensor([[[-0.0061, -0.0004,  0.0002,  ..., -0.0084, -0.0009,  0.0044]]],
       device='cuda:0', dtype=torch.float16)

In [73]:
torch.tensor([[128]], dtype=torch.int32).cuda()

tensor([[128]], device='cuda:0', dtype=torch.int32)

In [74]:
flash_attn_with_kvcache(q=gt['query_states'], k_cache=gt['key_states'], v_cache=gt['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True, cache_seqlens=torch.tensor([129], dtype=torch.int32).cuda()).reshape(1, 1, 4096)

tensor([[[-0.0061, -0.0004,  0.0002,  ..., -0.0084, -0.0009,  0.0044]]],
       device='cuda:0', dtype=torch.float16)

In [75]:
flash_attn_with_kvcache(q=re['query_states'], k_cache=re['key_states'], v_cache=re['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True, ).reshape(1, 1, 4096)

tensor([[[-0.0061, -0.0004,  0.0002,  ..., -0.0084, -0.0009,  0.0044]]],
       device='cuda:0', dtype=torch.float16)

In [76]:
flash_attn_with_kvcache(q=gt['query_states'], k_cache=gt['key_states'], v_cache=gt['value_states'], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True, cache_seqlens=torch.tensor([129], dtype=torch.int32).cuda()).reshape(1, 1, 4096)

tensor([[[-0.0061, -0.0004,  0.0002,  ..., -0.0084, -0.0009,  0.0044]]],
       device='cuda:0', dtype=torch.float16)

In [88]:
# Cross-check in the other direction: run the kernel on gt's query and the
# first 130 rows of gt's key/value cache, with cache_seqlens set to 129, and
# compare against re's saved attn_output using rtol=atol=1e-4.
# NOTE(review): presumably cache_seqlens limits attention to the first 129
# cache rows -- confirm against the flash-attn documentation.
torch.allclose(re['attn_output'], flash_attn_with_kvcache(q=gt['query_states'], k_cache=gt['key_states'][:,:130], v_cache=gt['value_states'][:,:130], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True, cache_seqlens=torch.tensor([129], dtype=torch.int32).cuda()).reshape(1, 1, 4096), rtol=1e-4, atol=1e-4)

True

In [45]:
flash_attn_with_kvcache(q=gt['query_states'], k_cache=gt['key_states'][:,:330], v_cache=gt['value_states'][:,:330], softmax_scale=1/torch.sqrt(torch.tensor(128, dtype=torch.float16)), causal=True, cache_seqlens=torch.tensor([129], dtype=torch.int32).cuda()).reshape(1, 1, 4096)

tensor([[[-0.0061, -0.0004,  0.0002,  ..., -0.0084, -0.0009,  0.0044]]],
       device='cuda:0', dtype=torch.float16)

In [37]:
gt['key_states'].shape

torch.Size([1, 896, 32, 128])

In [33]:
torch.allclose(re['query_states'], gt['query_states'])

True

In [5]:
# Batch sizes 1..8 crossed with four prefill lengths (122880, 64K, 32K, 16K).
bsz_list = list(range(1, 9))
prefill_list = [122880] + [kib * 1024 for kib in (64, 32, 16)]

# Longest prefill first; within each prefill, batch size ascends.
for prefill in prefill_list:
    for bsz in bsz_list:
        print(
            "CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py"
            f" --bsz {bsz} --prefill {prefill}"
            " --attn_method flash_repeat"
        )

CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 1 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 2 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 3 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 4 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 5 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 6 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 7 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 8 --prefill 122880 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 1 --prefill 65536 --attn_method flash_repeat
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --b