Llama shard axis 0 sometimes #5123

Merged: 6 commits into tinygrad:master on Jun 26, 2024
Conversation

@chaosagent (Contributor) commented on Jun 24, 2024:

depends on #5122

for 2 consecutive linear layers, if the first is sharded axis=0, then no all2all comms are needed between the two layers (see the sketch below)
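
A minimal sketch of that idea (not the code in this PR), assuming tinygrad's Tensor.shard/shard_ API and nn.Linear, whose weight is (out_features, in_features); the device list and layer sizes are made up for illustration:

```python
from tinygrad import Tensor, Device
from tinygrad.nn import Linear

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # hypothetical 2-device setup

up   = Linear(4096, 11008, bias=False)  # first of two consecutive linear layers
down = Linear(11008, 4096, bias=False)  # second layer

# shard the first layer's weight on axis=0 (its output features): its output
# then comes out sharded on the last axis, which is exactly the axis the second
# layer (sharded axis=1, its input features) contracts over locally. the only
# cross-device traffic left is an all-reduce of the partial sums -- no all2all
# between the two layers.
up.weight.shard_(GPUS, axis=0)
down.weight.shard_(GPUS, axis=1)

x = Tensor.rand(1, 4096).shard(GPUS, axis=None)  # replicated activations
y = down(up(x).silu()).realize()
```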

if the embedding is sharded axis=-1, then activations are sharded going into the transformer blocks and layer norms, resulting in a lot of extra all2all rounds

note that the embedding is (in_features, out_features), so sharding axis=0 means the entire embedding vector for a token is fetched on the one device that has it (with the rest contributing zero vectors) and all-reduced. ideally, we would shard axis=-1 and force an all-gather afterwards, but there is no way to do this in the current api; a sketch of the trade-off follows.
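
A hedged sketch of that embedding trade-off, assuming tinygrad's nn.Embedding (weight shape (vocab_size, embed_size)) and the same hypothetical device list as above:

```python
from tinygrad import Tensor, Device
from tinygrad.nn import Embedding

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # hypothetical device list
emb = Embedding(32000, 8192)  # illustrative vocab / embed sizes

# axis=0 splits the vocab: for any given token only one device holds its row,
# the others contribute zero vectors, and the lookup is completed with an
# all-reduce, so the activations enter the transformer blocks un-sharded.
emb.weight.shard_(GPUS, axis=0)

# axis=-1 would instead split the embedding dim: every device has every row,
# but the activations leave the embedding sharded on the last axis, which is
# what triggers the extra all2all rounds described above.
# emb.weight.shard_(GPUS, axis=-1)

tokens = Tensor([[1, 2, 3]]).shard(GPUS, axis=None)
out = emb(tokens).realize()  # (1, 3, 8192)
```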

benchmarks (note that only changing both the embeddings and w1/w3 helps). with both changes:

chaos@tiny13:~/tinygrad$ JITBEAM=4 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing
using AMD backend
using LLaMA-2-70B model
ram used: 137.97 GB, freqs_cis: 100% 724/724 [00:22<00:00, 32.40it/s]
loaded weights in 22351.37 ms, 137.97 GB loaded at 6.17 GB/s
Hello.
enqueue in 15591.43 ms
total 15592.59 ms, 0.06 tok/s, 9.57 GB/s, param 8.85 GB/s
 I
enqueue in 12874.70 ms
total 12875.18 ms, 0.08 tok/s, 10.80 GB/s, param 10.72 GB/s
 am
enqueue in 124206.95 ms
total 124208.07 ms, 0.01 tok/s, 1.12 GB/s, param 1.11 GB/s
 a
enqueue in  33.92 ms
total  72.58 ms, 13.78 tok/s, 1926.59 GB/s, param 1901.09 GB/s
 
enqueue in  40.08 ms
total  69.83 ms, 14.32 tok/s, 2002.82 GB/s, param 1975.94 GB/s
2
enqueue in  40.45 ms
total  69.82 ms, 14.32 tok/s, 2003.57 GB/s, param 1976.30 GB/s
0
enqueue in  41.65 ms
total  71.11 ms, 14.06 tok/s, 1967.59 GB/s, param 1940.44 GB/s
 year
enqueue in  37.58 ms
total  67.25 ms, 14.87 tok/s, 2080.70 GB/s, param 2051.61 GB/s
 old
enqueue in  40.70 ms
total  70.32 ms, 14.22 tok/s, 1990.34 GB/s, param 1962.14 GB/s
 female
enqueue in  41.81 ms
total  71.56 ms, 13.97 tok/s, 1956.20 GB/s, param 1928.12 GB/s
.
output validated

With just the embeddings change:

chaos@tiny13:~/tinygrad$ JITBEAM=4 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing
using AMD backend
using LLaMA-2-70B model
ram used: 137.97 GB, freqs_cis: 100% 724/724 [00:22<00:00, 32.66it/s]
loaded weights in 22172.30 ms, 137.97 GB loaded at 6.22 GB/s
Hello.
enqueue in 14666.25 ms
total 14667.33 ms, 0.07 tok/s, 10.44 GB/s, param 9.41 GB/s
 I
enqueue in 12092.77 ms
total 12093.23 ms, 0.08 tok/s, 11.61 GB/s, param 11.41 GB/s
 am
enqueue in 32450.43 ms
total 32451.64 ms, 0.03 tok/s, 4.33 GB/s, param 4.25 GB/s
 a
enqueue in  51.34 ms
total 107.64 ms, 9.29 tok/s, 1307.65 GB/s, param 1281.83 GB/s
 
enqueue in  52.16 ms
total 102.00 ms, 9.80 tok/s, 1380.29 GB/s, param 1352.78 GB/s
2
enqueue in  51.41 ms
total 100.69 ms, 9.93 tok/s, 1398.41 GB/s, param 1370.28 GB/s
0
enqueue in  51.69 ms
total 101.34 ms, 9.87 tok/s, 1389.76 GB/s, param 1361.56 GB/s
 year
enqueue in  45.10 ms
total 101.12 ms, 9.89 tok/s, 1393.05 GB/s, param 1364.52 GB/s
 old
enqueue in  51.68 ms
total 101.32 ms, 9.87 tok/s, 1390.60 GB/s, param 1361.87 GB/s
 female
enqueue in  52.28 ms
total 101.81 ms, 9.82 tok/s, 1384.12 GB/s, param 1355.27 GB/s
.
output validated

With just the w1/w3 change:

chaos@tiny13:~/tinygrad$ JITBEAM=4 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing
using AMD backend
using LLaMA-2-70B model
ram used: 137.97 GB, freqs_cis: 100% 724/724 [00:22<00:00, 32.67it/s]
loaded weights in 22161.13 ms, 137.97 GB loaded at 6.23 GB/s
Hello.
enqueue in 15478.05 ms
total 15479.21 ms, 0.06 tok/s, 9.57 GB/s, param 8.91 GB/s
 I
enqueue in 12866.46 ms
total 12866.93 ms, 0.08 tok/s, 10.78 GB/s, param 10.72 GB/s
 am
enqueue in 16945.41 ms
total 16946.63 ms, 0.06 tok/s, 8.19 GB/s, param 8.14 GB/s
 a
enqueue in  55.49 ms
total 106.19 ms, 9.42 tok/s, 1307.03 GB/s, param 1299.31 GB/s
 
enqueue in  63.03 ms
total 108.91 ms, 9.18 tok/s, 1274.70 GB/s, param 1266.93 GB/s
2
enqueue in  60.59 ms
total 106.66 ms, 9.38 tok/s, 1301.81 GB/s, param 1293.62 GB/s
0
enqueue in  62.94 ms
total 108.84 ms, 9.19 tok/s, 1275.95 GB/s, param 1267.68 GB/s
 year
enqueue in  62.01 ms
total 108.22 ms, 9.24 tok/s, 1283.55 GB/s, param 1275.00 GB/s
 old
enqueue in  63.54 ms
total 109.62 ms, 9.12 tok/s, 1267.37 GB/s, param 1258.68 GB/s
 female
enqueue in  62.49 ms
total 108.56 ms, 9.21 tok/s, 1280.01 GB/s, param 1271.00 GB/s
.
output validated

baseline:

chaos@tiny13:~/tinygrad$ JITBEAM=4 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing
using AMD backend
using LLaMA-2-70B model
ram used: 137.97 GB, freqs_cis: 100% 724/724 [00:22<00:00, 32.42it/s]
loaded weights in 22331.48 ms, 137.97 GB loaded at 6.18 GB/s
Hello.
enqueue in 14730.28 ms
total 14731.31 ms, 0.07 tok/s, 10.32 GB/s, param 9.37 GB/s
 I
enqueue in 12173.18 ms
total 12173.60 ms, 0.08 tok/s, 11.50 GB/s, param 11.33 GB/s
 am
enqueue in 15047.83 ms
total 15049.00 ms, 0.07 tok/s, 9.31 GB/s, param 9.17 GB/s
 a
enqueue in  55.27 ms
total 101.20 ms, 9.88 tok/s, 1384.13 GB/s, param 1363.45 GB/s
 
enqueue in  63.20 ms
total 108.59 ms, 9.21 tok/s, 1290.14 GB/s, param 1270.63 GB/s
2
enqueue in  59.28 ms
total 104.46 ms, 9.57 tok/s, 1341.45 GB/s, param 1320.92 GB/s
0
enqueue in  64.18 ms
total 109.32 ms, 9.15 tok/s, 1282.01 GB/s, param 1262.15 GB/s
 year
enqueue in  61.44 ms
total 106.53 ms, 9.39 tok/s, 1315.88 GB/s, param 1295.25 GB/s
 old
enqueue in  63.20 ms
total 108.86 ms, 9.19 tok/s, 1287.91 GB/s, param 1267.48 GB/s
 female
enqueue in  61.81 ms
total 107.25 ms, 9.32 tok/s, 1307.44 GB/s, param 1286.46 GB/s
.
output validated

@geohot (Collaborator) commented on Jun 24, 2024:

updated the bounty to $700 for ">20 tok/s running LLaMA 3 70B in FP16 on a tinybox" and locked it to you. you are welcome to do it in multiple PRs

Changes

Name                 Lines    Diff    Tokens/Line    Diff
-----------------  -------  ------  -------------  ------
tinygrad/multi.py      132      +0           22.1    +0.0


total lines changed: 0

@chaosagent marked this pull request as ready for review on June 26, 2024 at 02:58
@chenyuxyz (Collaborator) commented:
this is fine to merge, it makes 7B on 4 GPUs faster too (@wozeparrot you might want to copy this somehow in llama3). I cannot JITBEAM it directly though; did you save the ASTs from the search separately somehow? maybe add instructions in #3921

@chenyuxyz (Collaborator) commented:
7B with shard 4 and JITBEAM=2 (JITBEAM=4 is similar): 55 tok/s -> 75 tok/s on red.

@chenyuxyz merged commit 3604642 into tinygrad:master on Jun 26, 2024
21 checks passed