Gemma2 #709

Merged · 349 commits · Jul 3, 2024
Commits
d32e972
Update mapper.py
danielhanchen Jun 9, 2024
e121fa5
Update loader.py
danielhanchen Jun 9, 2024
5eaa10f
Update llama.py
danielhanchen Jun 9, 2024
f57d28d
Update tokenizer_utils.py
danielhanchen Jun 10, 2024
8937507
info
danielhanchen Jun 11, 2024
8982edb
edits
danielhanchen Jun 11, 2024
8904605
Create chat template
danielhanchen Jun 11, 2024
2a374c2
Fix tokenizer
danielhanchen Jun 12, 2024
d704b73
Merge branch 'main' into nightly
danielhanchen Jun 13, 2024
8176155
Update tokenizer_utils.py
danielhanchen Jun 13, 2024
21a99f1
fix case where gguf saving fails due to first_conversion dtype (#630)
chrehall68 Jun 13, 2024
dbf2dcf
Support revision parameter in FastLanguageModel.from_pretrained (#629)
chrehall68 Jun 13, 2024
9016171
clears any selected_adapters before calling internal_model.save_pretr…
neph1 Jun 13, 2024
0428920
Update __init__.py (#602)
xyangk Jun 13, 2024
9fdd847
Fixed unsloth/tokenizer_utils.py for chat training (#604)
Oseltamivir Jun 13, 2024
b5fc6aa
Add GGML saving option to Unsloth for easier Ollama model creation an…
mahiatlinux Jun 13, 2024
3fafbf7
docs: Add LoraConfig parameters documentation (#619)
sebdg Jun 13, 2024
273a871
llama.cpp failing (#371)
bet0x Jun 13, 2024
b312b3f
fix libcuda_dirs import for triton 3.0 (#227)
t-vi Jun 13, 2024
1601dca
Update save.py
danielhanchen Jun 13, 2024
26dc502
Update __init__.py
danielhanchen Jun 13, 2024
6a51657
Update fast_lora.py
danielhanchen Jun 13, 2024
4a8ba90
Update save.py
danielhanchen Jun 13, 2024
0abb5ba
Update save.py
danielhanchen Jun 13, 2024
b24dd05
Update save.py
danielhanchen Jun 13, 2024
48c6d6d
Update loader.py
danielhanchen Jun 13, 2024
e35f608
Update save.py
danielhanchen Jun 13, 2024
4822eae
Update save.py
danielhanchen Jun 13, 2024
7d847ed
quantize now llama-quantize
danielhanchen Jun 13, 2024
82f10cb
Update chat_templates.py
danielhanchen Jun 13, 2024
08424f0
Update loader.py
danielhanchen Jun 13, 2024
eb906d0
Update mapper.py
danielhanchen Jun 13, 2024
0a304ae
Update __init__.py
danielhanchen Jun 13, 2024
71edc42
embedding size
danielhanchen Jun 13, 2024
411b881
Merge branch 'main' into nightly
danielhanchen Jun 13, 2024
b74e321
Update qwen2.py
danielhanchen Jun 13, 2024
9c6d415
Merge branch 'main' into nightly
danielhanchen Jun 14, 2024
b82277f
docs
danielhanchen Jun 14, 2024
d98e45e
Update README.md
danielhanchen Jun 14, 2024
b6f0fdb
Update qwen2.py
danielhanchen Jun 14, 2024
6c031e4
README: Fix minor typo. (#559)
shaper Jun 14, 2024
2401dee
Update mistral.py
danielhanchen Jun 14, 2024
1b93d7e
Update qwen2.py
danielhanchen Jun 14, 2024
3581037
Update qwen2.py
danielhanchen Jun 14, 2024
b56b8b8
Update qwen2.py
danielhanchen Jun 14, 2024
fe8c064
Update llama.py
danielhanchen Jun 14, 2024
d8d332a
Update llama.py
danielhanchen Jun 14, 2024
cdb1dbb
Update llama.py
danielhanchen Jun 14, 2024
e8b3cf0
Update README.md
danielhanchen Jun 14, 2024
7e6f000
FastMistralModel
danielhanchen Jun 14, 2024
28995ab
Update mistral.py
danielhanchen Jun 14, 2024
515b1ae
Update mistral.py
danielhanchen Jun 14, 2024
7f28209
Update mistral.py
danielhanchen Jun 14, 2024
453cc48
Update mistral.py
danielhanchen Jun 14, 2024
6633d4a
Update mistral.py
danielhanchen Jun 14, 2024
e5bf125
Auto check rope scaling
danielhanchen Jun 14, 2024
d4f4bce
Merge branch 'main' into nightly
danielhanchen Jun 14, 2024
341565b
Update llama.py
danielhanchen Jun 14, 2024
dd3c6b1
Update llama.py
danielhanchen Jun 15, 2024
6d1ae23
Update llama.py
danielhanchen Jun 15, 2024
d855ef9
GPU support
danielhanchen Jun 15, 2024
da1fe76
Merge branch 'main' into nightly
danielhanchen Jun 15, 2024
6656446
Typo
danielhanchen Jun 15, 2024
9bd5fad
Update gemma.py
danielhanchen Jun 15, 2024
a3061b6
gpu
danielhanchen Jun 15, 2024
7e5155d
Merge branch 'main' into nightly
danielhanchen Jun 15, 2024
513bd4d
Multiple GGUF saving
danielhanchen Jun 15, 2024
fb54fbb
Update save.py
danielhanchen Jun 15, 2024
4cba3e2
Update save.py
danielhanchen Jun 15, 2024
979bb22
Merge branch 'main' into nightly
danielhanchen Jun 15, 2024
31811cf
check PEFT and base
danielhanchen Jun 15, 2024
a0232a7
Update llama.py
danielhanchen Jun 15, 2024
c4c4ff4
Update llama.py
danielhanchen Jun 15, 2024
80e82a2
Update llama.py
danielhanchen Jun 15, 2024
f62237d
Update llama.py
danielhanchen Jun 15, 2024
7f864dc
Update llama.py
danielhanchen Jun 15, 2024
4dda039
Update chat_templates.py
danielhanchen Jun 15, 2024
30605de
Fix breaking bug in save.py with interpreting quantization_method as …
ArcadaLabs-Jason Jun 16, 2024
1ba18ac
Merge branch 'main' into nightly
danielhanchen Jun 16, 2024
e2b2083
Revert "Fix breaking bug in save.py with interpreting quantization_me…
danielhanchen Jun 16, 2024
aeda849
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Jun 16, 2024
0938ab8
Revert "Revert "Fix breaking bug in save.py with interpreting quantiz…
danielhanchen Jun 16, 2024
46517a9
Merge branch 'main' into nightly
danielhanchen Jun 16, 2024
99f0397
Update llama.py
danielhanchen Jun 16, 2024
b824627
peft
danielhanchen Jun 16, 2024
e4f2263
patch
danielhanchen Jun 16, 2024
7aadceb
Update loader.py
danielhanchen Jun 16, 2024
17a9fb3
retrain
danielhanchen Jun 16, 2024
25cb17e
Update llama.py
danielhanchen Jun 16, 2024
3fba0f5
Update llama.py
danielhanchen Jun 16, 2024
798aa1e
Update llama.py
danielhanchen Jun 16, 2024
c6142d0
Update llama.py
danielhanchen Jun 16, 2024
7236dfc
Update llama.py
danielhanchen Jun 16, 2024
37f9abd
Update llama.py
danielhanchen Jun 16, 2024
7618197
Update llama.py
danielhanchen Jun 16, 2024
b4907f3
Update llama.py
danielhanchen Jun 16, 2024
2714a8b
Update llama.py
danielhanchen Jun 16, 2024
771a0d0
Update llama.py
danielhanchen Jun 16, 2024
af88eda
Merge branch 'main' into nightly
danielhanchen Jun 17, 2024
92c7d58
offload
danielhanchen Jun 17, 2024
2eacd2d
Update llama.py
danielhanchen Jun 17, 2024
b957061
Create a starter script for command-line training to integrate in ML …
sebdg Jun 18, 2024
4dba6c5
Update chat_templates.py
danielhanchen Jun 18, 2024
f2e4b83
Ollama
danielhanchen Jun 18, 2024
a5367a1
Update chat_templates.py
danielhanchen Jun 18, 2024
85062ff
Update chat_templates.py
danielhanchen Jun 18, 2024
12bd3cf
Update chat_templates.py
danielhanchen Jun 18, 2024
4417417
Update chat_templates.py
danielhanchen Jun 18, 2024
be02d97
Update chat_templates.py
danielhanchen Jun 18, 2024
0a1ee7a
Update chat_templates.py
danielhanchen Jun 18, 2024
7195152
Update chat_templates.py
danielhanchen Jun 18, 2024
676b20b
Update chat_templates.py
danielhanchen Jun 18, 2024
89b7807
Update chat_templates.py
danielhanchen Jun 18, 2024
563afa9
Update chat_templates.py
danielhanchen Jun 18, 2024
7bcaa23
Merge branch 'main' into nightly
danielhanchen Jun 19, 2024
bbfa77c
Ollama
danielhanchen Jun 19, 2024
9a88ac1
Update chat_templates.py
danielhanchen Jun 19, 2024
89e3027
ollama
danielhanchen Jun 19, 2024
412a7c9
Update mapper.py
danielhanchen Jun 19, 2024
44039e5
Update chat_templates.py
danielhanchen Jun 19, 2024
fc9ec40
Update save.py
danielhanchen Jun 19, 2024
7be6d88
Update save.py
danielhanchen Jun 19, 2024
5530710
Update save.py
danielhanchen Jun 19, 2024
d5b8d41
Update save.py
danielhanchen Jun 19, 2024
d1b3fac
Update save.py
danielhanchen Jun 19, 2024
99b68ff
Update save.py
danielhanchen Jun 19, 2024
a25dea9
Update save.py
danielhanchen Jun 19, 2024
bee81d1
Merge branch 'main' into nightly
danielhanchen Jun 20, 2024
e544c97
Update chat_templates.py
danielhanchen Jun 20, 2024
3f27de2
Update chat_templates.py
danielhanchen Jun 20, 2024
34c71b5
Update chat_templates.py
danielhanchen Jun 20, 2024
ae2924f
Update chat_templates.py
danielhanchen Jun 20, 2024
15e4e10
Merge branch 'main' into nightly
danielhanchen Jun 20, 2024
a5cbf3e
Update llama.py
danielhanchen Jun 20, 2024
b5888ba
Fixes
danielhanchen Jun 20, 2024
c37132b
Merge branch 'main' into nightly
danielhanchen Jun 21, 2024
f1076e8
clearer messages
danielhanchen Jun 21, 2024
ec26e40
Update tokenizer_utils.py
danielhanchen Jun 21, 2024
0e4189d
Update tokenizer_utils.py
danielhanchen Jun 21, 2024
0bc45bd
Update llama.py
danielhanchen Jun 21, 2024
586d72e
Update llama.py
danielhanchen Jun 21, 2024
0a52e24
Update llama.py
danielhanchen Jun 21, 2024
cf31aeb
log
danielhanchen Jun 21, 2024
6956277
Update __init__.py
danielhanchen Jun 21, 2024
ffe5eb4
Update llama.py
danielhanchen Jun 21, 2024
21b544a
Update __init__.py
danielhanchen Jun 21, 2024
a2a86b3
Merge branch 'main' into nightly
danielhanchen Jun 21, 2024
4fa816d
Create Merge.png
danielhanchen Jun 21, 2024
9b19f35
Create ollama.png
danielhanchen Jun 21, 2024
f331863
Gemma2
danielhanchen Jun 30, 2024
3510795
Update llama.py
danielhanchen Jun 30, 2024
80c32c7
Update loader.py
danielhanchen Jun 30, 2024
f9fdfba
Update pyproject.toml
danielhanchen Jun 30, 2024
e8ab057
Update pyproject.toml
danielhanchen Jun 30, 2024
0fb680a
Update llama.py
danielhanchen Jun 30, 2024
5ad42d6
Update llama.py
danielhanchen Jun 30, 2024
c0c9a67
Update llama.py
danielhanchen Jun 30, 2024
71f6afe
Update llama.py
danielhanchen Jun 30, 2024
8448a77
Update _utils.py
danielhanchen Jun 30, 2024
15b751d
Revert Gemma2
danielhanchen Jun 30, 2024
7c5d0ef
Update gemma2.py
danielhanchen Jun 30, 2024
3a6573f
Update gemma2.py
danielhanchen Jun 30, 2024
fae7957
Update gemma2.py
danielhanchen Jun 30, 2024
22160d6
Update gemma2.py
danielhanchen Jun 30, 2024
f2777ed
Update gemma2.py
danielhanchen Jun 30, 2024
73148cd
Update gemma2.py
danielhanchen Jun 30, 2024
876d9b6
Update gemma2.py
danielhanchen Jun 30, 2024
9649611
Update gemma2.py
danielhanchen Jun 30, 2024
b2e6e3c
Update rms_layernorm.py
danielhanchen Jun 30, 2024
5278ab1
Update gemma2.py
danielhanchen Jul 2, 2024
a76e5ec
logit softcapping
danielhanchen Jul 2, 2024
5511b76
Update cross_entropy_loss.py
danielhanchen Jul 2, 2024
e0f2da4
Update llama.py
danielhanchen Jul 2, 2024
648c985
Update llama.py
danielhanchen Jul 2, 2024
bba2ac9
Update gemma2.py
danielhanchen Jul 2, 2024
95c934d
Update gemma2.py
danielhanchen Jul 2, 2024
66f35fd
Update cross_entropy_loss.py
danielhanchen Jul 2, 2024
acf9771
Update llama.py
danielhanchen Jul 2, 2024
b6d83ef
Update llama.py
danielhanchen Jul 2, 2024
436b2a7
Update cross_entropy_loss.py
danielhanchen Jul 2, 2024
0271a7f
Update cross_entropy_loss.py
danielhanchen Jul 2, 2024
35d8200
Update llama.py
danielhanchen Jul 2, 2024
80a3f61
Update cross_entropy_loss.py
danielhanchen Jul 2, 2024
a46c839
Update cross_entropy_loss.py
danielhanchen Jul 2, 2024
47b0c2a
Update gemma2.py
danielhanchen Jul 2, 2024
6b8a1ae
Update gemma2.py
danielhanchen Jul 2, 2024
d834223
Update gemma2.py
danielhanchen Jul 2, 2024
69d6f53
Update gemma2.py
danielhanchen Jul 2, 2024
03af2ee
Update gemma2.py
danielhanchen Jul 2, 2024
03ded35
Update gemma2.py
danielhanchen Jul 2, 2024
f16479d
Update gemma2.py
danielhanchen Jul 2, 2024
af98ffb
Update gemma2.py
danielhanchen Jul 2, 2024
41f38c8
Update gemma2.py
danielhanchen Jul 2, 2024
0244e2f
Update gemma2.py
danielhanchen Jul 2, 2024
0c08818
Update llama.py
danielhanchen Jul 2, 2024
5b91c90
Update gemma2.py
danielhanchen Jul 2, 2024
96be4bb
Update llama.py
danielhanchen Jul 2, 2024
f2388f2
Update llama.py
danielhanchen Jul 2, 2024
74d9301
Update gemma2.py
danielhanchen Jul 2, 2024
980937b
Update gemma2.py
danielhanchen Jul 2, 2024
8984422
Update llama.py
danielhanchen Jul 2, 2024
4b5a242
Update gemma2.py
danielhanchen Jul 2, 2024
673f546
Update gemma2.py
danielhanchen Jul 2, 2024
2fb91b1
Update gemma2.py
danielhanchen Jul 2, 2024
320d92c
Update gemma2.py
danielhanchen Jul 2, 2024
b27c010
Update gemma2.py
danielhanchen Jul 2, 2024
e6dcc39
Update gemma2.py
danielhanchen Jul 2, 2024
79d73b4
Update gemma2.py
danielhanchen Jul 2, 2024
d724f83
Update gemma2.py
danielhanchen Jul 2, 2024
f2cf1e9
Update gemma2.py
danielhanchen Jul 2, 2024
a5dac5c
Update gemma2.py
danielhanchen Jul 2, 2024
cc7692d
Update gemma2.py
danielhanchen Jul 2, 2024
32da093
Update gemma2.py
danielhanchen Jul 2, 2024
f859faf
Update gemma2.py
danielhanchen Jul 2, 2024
9246a8b
Update gemma2.py
danielhanchen Jul 2, 2024
a210034
Update gemma2.py
danielhanchen Jul 2, 2024
b7d4338
Update gemma2.py
danielhanchen Jul 2, 2024
11036fe
Update gemma2.py
danielhanchen Jul 2, 2024
4b3d589
Update gemma2.py
danielhanchen Jul 2, 2024
4bfbdd5
Update _utils.py
danielhanchen Jul 2, 2024
3d0a5bc
Update _utils.py
danielhanchen Jul 2, 2024
478ef22
Update gemma2.py
danielhanchen Jul 2, 2024
2847d6d
compile flags
danielhanchen Jul 2, 2024
af2630d
Update _utils.py
danielhanchen Jul 2, 2024
4b5f8bd
Update _utils.py
danielhanchen Jul 2, 2024
1f3cee6
Update _utils.py
danielhanchen Jul 2, 2024
cf68600
Update _utils.py
danielhanchen Jul 2, 2024
edab699
Update _utils.py
danielhanchen Jul 2, 2024
546ddfb
Update _utils.py
danielhanchen Jul 2, 2024
97301fa
Update _utils.py
danielhanchen Jul 2, 2024
4516437
Update _utils.py
danielhanchen Jul 2, 2024
40926e0
Update _utils.py
danielhanchen Jul 2, 2024
d4616d8
Update gemma2.py
danielhanchen Jul 3, 2024
58f016a
Update gemma2.py
danielhanchen Jul 3, 2024
7cea034
fixes
danielhanchen Jul 3, 2024
f468976
Update _utils.py
danielhanchen Jul 3, 2024
c266d6b
Fix generation
danielhanchen Jul 3, 2024
30d3e7c
Update llama.py
danielhanchen Jul 3, 2024
9b2beba
Update llama.py
danielhanchen Jul 3, 2024
8526c3e
Update _utils.py
danielhanchen Jul 3, 2024
647f894
Update _utils.py
danielhanchen Jul 3, 2024
018d632
Update _utils.py
danielhanchen Jul 3, 2024
f8213cf
pad token
danielhanchen Jul 3, 2024
5e5f05d
Update gemma2.py
danielhanchen Jul 3, 2024
62a66d6
pad token
danielhanchen Jul 3, 2024
e67f335
Update _utils.py
danielhanchen Jul 3, 2024
505a017
Update llama.py
danielhanchen Jul 3, 2024
e46dab7
Update gemma2.py
danielhanchen Jul 3, 2024
9d83ced
edit warning
danielhanchen Jul 3, 2024
90d3b65
Update tokenizer_utils.py
danielhanchen Jul 3, 2024
Binary file added images/Merge.png
Binary file added images/ollama.png
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -34,7 +34,7 @@ exclude = ["images*"]
[project.optional-dependencies]
huggingface = [
"tyro",
"transformers>=4.38.2",
"transformers>=4.42.3",
"datasets>=2.16.0",
"sentencepiece>=0.2.0",
"tqdm",
@@ -185,9 +185,9 @@ colab-ampere-torch220 = [
]
colab-new = [
"tyro",
"transformers>=4.38.2",
"transformers>=4.42.3",
"datasets>=2.16.0",
"sentencepiece",
"sentencepiece>=0.2.0",
"tqdm",
"psutil",
"wheel>=0.42.0",
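The dependency bumps track the Gemma 2 work: transformers moves to >=4.42.3 (the 4.42 line appears to be the first with Gemma 2 support) and the sentencepiece floor becomes consistent across the extras. A minimal sketch, not part of the PR, for checking an existing environment against the new pins (assumes importlib.metadata and the packaging library are available):

# Hypothetical sanity check for the bumped pins; not unsloth code.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("transformers")) >= Version("4.42.3")
assert Version(version("sentencepiece")) >= Version("0.2.0")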
102 changes: 73 additions & 29 deletions unsloth/kernels/cross_entropy_loss.py
@@ -19,14 +19,17 @@
from transformers.models.llama.modeling_llama import logger


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -58,29 +61,38 @@ def _cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = logsumexp - x
x = tl.load(logits_ptr + label_idx)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
pass
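
Restating the forward kernel above (not part of the diff): with DO_SOFTCAPPING enabled, each logit is capped before the loss is formed,

    z'_i = SOFTCAP * tanh(z_i / SOFTCAP)
    loss  = logsumexp(z') - z'_label        (loss = 0 when label == -100)

so every logit is bounded to (-SOFTCAP, SOFTCAP) before the logsumexp, which is why the label logit is re-capped separately before subtraction.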


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
256K vocab divided in 4 chunks
Expand Down Expand Up @@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE

label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)

logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

@@ -126,7 +142,9 @@
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = -1.0 * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -173,15 +194,27 @@ def _cross_entropy_backward(
else:
dloss = 0.0

x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = tl.math.tanh(x / SOFTCAP)
x = SOFTCAP * partial
pass

logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x - logsumexp)
y = tl.exp(x.to(tl.float32) - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)

if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
pass

# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
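Restating the backward kernel above: with softcapping, the stored gradient is the usual softmax-minus-one-hot term chained through the cap's derivative,

    dL/dz_j = (softmax(z')_j - 1[j == label]) * (1 - tanh^2(z_j / SOFTCAP))

which is exactly the extra `y * (1.0 - partial*partial)` factor written back into logits_ptr; without softcapping the second factor is simply 1.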

class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels):
def forward(ctx, logits, labels, logit_softcapping = 0):
n_rows, vocab_size = logits.shape

div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

DO_SOFTCAPPING = (logit_softcapping != 0)

if n_chunks == 1:
# For small vocabs <= 65536 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
)
else:
# For large vocabs > 65536 like Gemma 256K
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")

_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
num_warps = 32,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
pass

ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
return losses
pass

dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = 8,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
num_warps = 8,
)
return logits, None, None,
pass
pass


def fast_cross_entropy_loss(logits, labels):
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
loss = Fast_CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
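For readers cross-checking the Triton path, a minimal eager-mode sketch (not from this PR; the comparison below is illustrative) of what fast_cross_entropy_loss is intended to compute — mean cross entropy over non-masked (-100) positions, with optional Gemma 2-style logit softcapping:

import torch
import torch.nn.functional as F

def reference_softcapped_ce(logits, labels, logit_softcapping = 0.0):
    # logits: (batch, seq_len, vocab); labels: (batch, seq_len) with -100 at masked positions.
    batch, seq_len, vocab = logits.shape
    logits = logits.float().view(-1, vocab)
    labels = labels.view(-1)
    if logit_softcapping != 0:
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    total = F.cross_entropy(logits, labels, ignore_index = -100, reduction = "sum")
    return total / (labels != -100).sum()

# Illustrative comparison (30.0 is often cited as Gemma 2's final-logit softcap):
# torch.testing.assert_close(
#     fast_cross_entropy_loss(logits, labels, logit_softcapping = 30.0),
#     reference_softcapped_ce(logits, labels, 30.0),
#     rtol = 1e-3, atol = 1e-3,
# )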
4 changes: 2 additions & 2 deletions unsloth/kernels/geglu.py
@@ -41,7 +41,7 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def geglu_exact_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
@@ -133,7 +133,7 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def geglu_approx_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
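Only the output buffer's device changes here, but as a reference for what the two fused kernels compute (my reading of the surrounding code, not part of the diff): exact GEGLU applies erf-based GELU to the gate, the approximate variant uses the tanh approximation, and both multiply elementwise by the up projection.

import torch
import torch.nn.functional as F

def geglu_exact_reference(gate, up):
    # Exact (erf) GELU on the gate, then elementwise product with up.
    return F.gelu(gate, approximate = "none") * up

def geglu_approx_reference(gate, up):
    # Tanh-approximated GELU, matching the "approx" kernel's formulation.
    return F.gelu(gate, approximate = "tanh") * up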
6 changes: 3 additions & 3 deletions unsloth/kernels/rms_layernorm.py
@@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward(
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
@@ -137,8 +137,8 @@ def forward(ctx, X, W, eps, gemma = False):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")

fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](
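The Gemma RMSNorm forward shown above normalizes in float32 and scales by (weight + 1); replacing 1.0 / tl.sqrt(...) with tl.math.rsqrt(...) gives the same value up to float32 rounding while using the fused reciprocal square root. A small eager sketch of the same computation (my restatement, not unsloth code; the eps default is a placeholder — the real value comes from the model config):

import torch

def gemma_rms_layernorm_reference(X, W, eps = 1e-6):
    # Row-wise RMS norm in float32 with Gemma's (1 + weight) scaling.
    X32 = X.float()
    inv_var = torch.rsqrt(X32.pow(2).mean(dim = -1, keepdim = True) + eps)
    return (X32 * inv_var * (W.float() + 1.0)).to(X.dtype)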
2 changes: 1 addition & 1 deletion unsloth/kernels/swiglu.py
@@ -41,7 +41,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda")
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
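As with GEGLU, only the buffer device changes here; for reference, the fused _fg_kernel computes SwiGLU — silu(e) * g, i.e. e * sigmoid(e) * g (my reading of the kernel, shown as an eager sketch, not part of the diff):

import torch
import torch.nn.functional as F

def swiglu_reference(e, g):
    # h = silu(e) * g = e * sigmoid(e) * g
    return F.silu(e) * g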
8 changes: 4 additions & 4 deletions unsloth/kernels/utils.py
@@ -105,14 +105,14 @@ def fast_dequantize(W, quant_state = None, out = None):

# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda")
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)

# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")

# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
@@ -161,7 +161,7 @@ def fast_gemv(X, W, quant_state, out = None):
bout = shape[0]

if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda")
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)

df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
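A recurring change across these kernel files is pinning temporary buffers to device = "cuda:0" rather than the generic "cuda" string. My reading (the diff itself does not explain it): "cuda" resolves to whichever CUDA device is current, so scratch buffers could land on a different GPU than the kernel's inputs in multi-GPU setups, while "cuda:0" always means the first visible device. A tiny illustration:

import torch

# "cuda" follows torch.cuda.current_device(); "cuda:0" is always index 0
# (CUDA_VISIBLE_DEVICES still controls which physical GPU that is).
a = torch.empty(4, device = "cuda")    # allocated on the current device
b = torch.empty(4, device = "cuda:0")  # always allocated on device index 0
print(a.device, b.device)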